Classify Gender Using Long Short-Term Memory Networks

This example shows how to classify the gender of a speaker using deep learning. In particular, the example uses a bidirectional long short-term memory (BiLSTM) network and gammatone cepstral coefficients (GTCC), pitch, harmonic ratio, and several spectral shape descriptors.

Introduction

Gender classification based on speech signals is an essential component of many audio systems, such as automatic speech recognition, speaker recognition, and content-based multimedia indexing.

This example uses long short-term memory (LSTM) networks, a type of recurrent neural network (RNN) well-suited to study sequence and time-series data. An LSTM network can learn long-term dependencies between time steps of a sequence. An LSTM layer (lstmLayer) can look at the time sequence in the forward direction, while a bidirectional LSTM layer (bilstmLayer) can look at the time sequence in both forward and backward directions. This example uses a bidirectional LSTM layer.

This example trains the LSTM network with sequences of gammatone cepstral coefficients (GTCC), pitch estimates, harmonic ratio, and several spectral shape descriptors.

The example goes through the following steps:

  1. Create an audioDatastore that points to the audio speech files used to train the LSTM network.

  2. Remove silence and non-speech segments from the speech files using a simple thresholding technique.

  3. Extract feature sequences consisting of GTCC, pitch, harmonic ratio, and several spectral shape descriptors from the speech signals.

  4. Train the LSTM network using the feature sequences.

  5. Measure and visualize the accuracy of the classifier on the training data.

  6. Create an audioDatastore of speech files used to test the trained network.

  7. Remove non-speech segments from these files, generate feature sequences, pass them through the network, and test its accuracy by comparing the predicted and actual gender of the speakers.

To accelerate the training process, run this example on a machine with a GPU. If your machine has a GPU and Parallel Computing Toolbox™, then MATLAB® automatically uses the GPU for training; otherwise, it uses the CPU.

Examine the Dataset

This example uses the Mozilla Common Voice dataset [1]. The dataset contains 48 kHz recordings of subjects speaking short sentences. Download the dataset and untar the downloaded file. Set datafolder to the location of the data.

datafolder = PathToDatabase;

Use audioDatastore to create a datastore for the files in the cv-valid-train folder.

ads = audioDatastore(fullfile(datafolder,"cv-valid-train"));

Use readtable to read the metadata associated with the audio files. The metadata is contained in the cv-valid-train.csv file. Inspect the first few rows of the metadata.

metadata = readtable(fullfile(datafolder,"cv-valid-train.csv"));
head(metadata)
ans =

  8×8 table

                 filename                                                           text                                              up_votes    down_votes       age         gender     accent    duration
    __________________________________    ________________________________________________________________________________________    ________    __________    __________    ________    ______    ________

    'cv-valid-train/sample-000000.mp3'    'learn to recognize omens and follow them the old king had said'                               1            0         ''            ''           ''         NaN   
    'cv-valid-train/sample-000001.mp3'    'everything in the universe evolved he said'                                                   1            0         ''            ''           ''         NaN   
    'cv-valid-train/sample-000002.mp3'    'you came so that you could learn about your dreams said the old woman'                        1            0         ''            ''           ''         NaN   
    'cv-valid-train/sample-000003.mp3'    'so now i fear nothing because it was those omens that brought you to me'                      1            0         ''            ''           ''         NaN   
    'cv-valid-train/sample-000004.mp3'    'if you start your emails with greetings let me be the first to welcome you to earth'          3            2         ''            ''           ''         NaN   
    'cv-valid-train/sample-000005.mp3'    'a shepherd may like to travel but he should never forget about his sheep'                     1            0         'twenties'    'female'     'us'       NaN   
    'cv-valid-train/sample-000006.mp3'    'night fell and an assortment of fighting men and merchants entered and exited the tent'       3            0         ''            ''           ''         NaN   
    'cv-valid-train/sample-000007.mp3'    'i heard a faint movement under my feet'                                                       2            1         ''            ''           ''         NaN   

You will use data corresponding to adult speakers only. Read the gender and age variables from the metadata. Make sure the files in the metadata and the datastore are arranged in the same order.

csvFiles   = metadata.filename;
[~,csvInd] = sort(csvFiles);

% Read the gender and age variables from the table.
gender     = metadata.gender;
age        = metadata.age;

adsFiles   = ads.Files;
[~,adsInd] = sort(adsFiles);

% Re-arrange gender and age to ensure the information is linked to the correct files.
gender     = gender(csvInd(adsInd));
age        = age(csvInd(adsInd));

Assign gender to the Labels property of the datastore.

ads.Labels = gender;

Use shuffle to randomize the order of the files in the datastore.

ads = shuffle(ads);

Not all files in the dataset are annotated with gender and age information. Create a subset of the datastore that only contains files where gender information is available and age is greater than 19.

maleOrfemale = categorical(ads.Labels) == "male" | categorical(ads.Labels) == "female";
isAdult      = categorical(age) ~= "" & categorical(age) ~= "teens";
ads          = subset(ads,maleOrfemale & isAdult);

Use countEachLabel to inspect the gender breakdown of the files.

countEachLabel(ads)
ans =

  2×2 table

    Label     Count
    ______    _____

    female    17747
    male      53494

You will train the deep learning network on a subset of the files. Create a datastore subset containing 3000 male speakers and 3000 female speakers.

ismale   = find(categorical(ads.Labels) == "male");
isfemale = find(categorical(ads.Labels) == "female");
numFilesPerGender = 3000;
ads = subset(ads,[ismale(1:numFilesPerGender); isfemale(1:numFilesPerGender)]);

Use shuffle again to randomize the order of the files in the datastore.

ads = shuffle(ads);

Use countEachLabel to verify the gender breakdown of the training set.

countEachLabel(ads)
ans =

  2×2 table

    Label     Count
    ______    _____

    female    3000 
    male      3000 

Isolate Speech Segments

Read the contents of an audio file using read.

[audio,info] = read(ads);
Fs           = info.SampleRate;

Plot the audio signal and then listen to it using the sound command.

timeVector = (1/Fs) * (0:numel(audio)-1);
figure
plot(timeVector,audio)
ylabel("Amplitude")
xlabel("Time (s)")
title("Sample Audio")
grid on

sound(audio,Fs)

The speech signal has silence segments that do not contain useful information pertaining to the gender of the speaker. This example removes silence using a simplified version of the thresholding approach described in [2]. The steps of the silence removal algorithm are outlined below.

First, compute two features over non-overlapping frames of the audio data: the signal energy and the spectral centroid. The spectral centroid is a measure of the "center of gravity" of a signal spectrum.
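
For reference, here is a minimal sketch of the spectral centroid of a single 50 ms frame, computed directly from its magnitude spectrum. The frameLen variable is hypothetical and is used only for this illustration.

frameLen = round(50e-3*Fs);                 % hypothetical 50 ms frame (samples)
frame    = audio(1:frameLen);               % one frame of the audio signal
X        = abs(fft(frame));
X        = X(1:floor(frameLen/2)+1);        % one-sided magnitude spectrum
f        = (0:numel(X)-1).'*Fs/frameLen;    % frequency of each bin (Hz)
frameCentroid = sum(f.*X)/sum(X);           % amplitude-weighted mean frequency (Hz)

The spectralCentroid function used below performs an equivalent computation for every analysis window; its values can differ slightly from this sketch depending on the spectrum type it uses internally.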

Break the audio into 50-millisecond non-overlapping frames.

audio        = audio ./ max(abs(audio)); % Normalize amplitude
windowLength = 50e-3 * Fs;
segments     = buffer(audio,windowLength);

Compute the energy and spectral centroid for each frame.

win = hann(windowLength,'periodic');
signalEnergy = sum(segments.^2,1)/windowLength;
centroid = spectralCentroid(segments,Fs,'Window',win,'OverlapLength',0);

Next, set thresholds for each feature. Regions where the energy falls below its threshold or the spectral centroid rises above its threshold are treated as non-speech and are disregarded. In this example, the energy threshold is set to half the mean energy and the spectral centroid threshold is set to 5000 Hz.

T_E            = mean(signalEnergy)/2;
T_C            = 5000;
isSpeechRegion = (signalEnergy>=T_E) & (centroid<=T_C);

Visualize the computed energy and spectral centroid over time.

% Replicate the per-frame signal energy, spectral centroid, and speech decision
% values to per-sample vectors for plotting purposes.
CC = repmat(centroid,windowLength,1);
CC = CC(:);
EE = repmat(signalEnergy,windowLength,1);
EE = EE(:);
flags2 = repmat(isSpeechRegion,windowLength,1);
flags2 = flags2(:);

figure

subplot(3,1,1)
plot(timeVector, CC(1:numel(audio)), ...
     timeVector, repmat(T_C,1,numel(timeVector)), "LineWidth",2)
xlabel("Time (s)")
ylabel("Normalized Centroid")
legend("Centroid","Threshold")
title("Spectral Centroid")
grid on

subplot(3,1,2)
plot(timeVector, EE(1:numel(audio)), ...
     timeVector, repmat(T_E,1,numel(timeVector)),"LineWidth",2)
ylabel("Normalized Energy")
legend("Energy","Threshold")
title("Window Energy")
grid on

subplot(3,1,3)
plot(timeVector, audio, ...
     timeVector,flags2(1:numel(audio)),"LineWidth",2)
ylabel("Audio")
legend("Audio","Speech Region")
title("Audio")
grid on
ylim([-1 1.1])

Extract the segments of speech from the audio. Assume speech is present for samples where energy is above its threshold and the spectral centroid is below its threshold.

% Get indices of frames where a speech-to-silence or silence-to-speech
% transition occurs.
regionStartPos = find(diff([isSpeechRegion(1)-1, isSpeechRegion]));

% Get the length of the all-silence or all-speech regions.
RegionLengths  = diff([regionStartPos, numel(isSpeechRegion)+1]);

% Get speech-only regions.
isSpeechRegion = isSpeechRegion(regionStartPos) == 1;
regionStartPos = regionStartPos(isSpeechRegion);
RegionLengths  = RegionLengths(isSpeechRegion);

% Get start and end indices for each speech region. Extend the region by 5
% windows on each side.
startIndices = zeros(1,numel(RegionLengths));
endIndices   = zeros(1,numel(RegionLengths));
for index=1:numel(RegionLengths)
   startIndices(index) = max(1, (regionStartPos(index) - 5) * windowLength + 1);
   endIndices(index)   = min(numel(audio), (regionStartPos(index) + RegionLengths(index) + 5) * windowLength);
end

Finally, merge intersecting speech segments.

activeSegment       = 1;
isSegmentsActive    = zeros(1,numel(startIndices));
isSegmentsActive(1) = 1;
for index = 2:numel(startIndices)
    if startIndices(index) <= endIndices(activeSegment)
        % Current segment intersects with previous segment
        if endIndices(index) > endIndices(activeSegment)
           endIndices(activeSegment) =  endIndices(index);
        end
    else
        % New speech segment detected
        activeSegment = index;
        isSegmentsActive(index) = 1;
    end
end
numSegments = sum(isSegmentsActive);
segments    = cell(1,numSegments);
limits      = zeros(2,numSegments);
speechSegmentsIndices  = find(isSegmentsActive);
for index = 1:length(speechSegmentsIndices)
    segments{index} = audio(startIndices(speechSegmentsIndices(index)): ...
                            endIndices(speechSegmentsIndices(index)));
    limits(:,index) = [startIndices(speechSegmentsIndices(index)) ; ...
                       endIndices(speechSegmentsIndices(index))];
end

Plot the original audio along with the detected speech segments.

figure

plot(timeVector,audio)
hold on
myLegend = cell(1,numel(segments)+1);
myLegend{1} = "Original Audio";
for index = 1:numel(segments)
    plot(timeVector(limits(1,index):limits(2,index)),segments{index});
    myLegend{index+1} = sprintf("Output Audio Segment %d",index);
end
xlabel("Time (s)")
ylabel("Audio")
grid on
legend(myLegend)

Audio Features

A speech signal is dynamic in nature and changes over time. Speech is commonly assumed to be stationary on short time scales, so it is typically processed in windows of 20-40 ms. For each speech segment, this example extracts audio features for 30 ms windows with 75% overlap.

win = hamming(0.03*Fs,"periodic");
overlapLength = 0.75*numel(win);
featureParams = struct("SampleRate",Fs, ...
                 "Window",win, ...
                 "OverlapLength",overlapLength);

This example trains the LSTM network using gammatone cepstral coefficients (GTCC), the delta and delta-delta GTCC, pitch estimates, harmonic ratio, and several spectral shape descriptors.
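
The feature extraction itself is performed by the helper function HelperGetFeatureSequences (used later with tall arrays). As a rough, minimal sketch of how such per-window features could be computed for one speech segment using Audio Toolbox functions, consider the following. The exact set of spectral descriptors (and therefore the total of 50 features per vector) used by the helper function may differ from this sketch.

% Sketch only: extract per-window features for one speech segment.
segment = segments{1};

% Gammatone cepstral coefficients plus delta and delta-delta.
[coeffs,delta,deltaDelta] = gtcc(segment,Fs, ...
    "Window",featureParams.Window, ...
    "OverlapLength",featureParams.OverlapLength);

% Pitch estimate and harmonic ratio per window.
f0 = pitch(segment,Fs, ...
    "WindowLength",numel(featureParams.Window), ...
    "OverlapLength",featureParams.OverlapLength);
hr = harmonicRatio(segment,Fs, ...
    "Window",featureParams.Window, ...
    "OverlapLength",featureParams.OverlapLength);

% Two example spectral shape descriptors (the helper may use more).
sc = spectralCentroid(segment,Fs, ...
    "Window",featureParams.Window, ...
    "OverlapLength",featureParams.OverlapLength);
ss = spectralSlope(segment,Fs, ...
    "Window",featureParams.Window, ...
    "OverlapLength",featureParams.OverlapLength);

% One feature vector per analysis window (rows are windows).
featureVectors = [coeffs,delta,deltaDelta,f0,hr,sc,ss];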

For each speech segment, the feature vectors are concatenated into sequences with 50% overlap. Each feature vector contains 50 features. Each sequence contains 40 feature vectors. Define the sequence parameters as a struct.

sequenceParams = struct("NumFeatures",50, ...
                 "SequenceLength",40, ...
                 "HopLength",20);

The figure provides an overview of the feature extraction used in this example.

Extract Features Using Tall Arrays

To speed up processing, extract feature sequences from the speech segments of all audio files in the datastore using tall arrays. Unlike in-memory arrays, tall arrays typically remain unevaluated until you request that the calculations be performed using the gather function. This deferred evaluation allows you to work quickly with large data sets. When you eventually request the output using gather, MATLAB® combines the queued calculations where possible and takes the minimum number of passes through the data. If you have Parallel Computing Toolbox™, you can use tall arrays in your local MATLAB® session, or on a local parallel pool. You can also run tall array calculations on a cluster if you have MATLAB® Parallel Server™ installed.

First, convert the datastore to a tall array:

T = tall(ads)
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 12).

T =

  M×1 tall cell array

    { 55920×1 double}
    {236784×1 double}
    {380784×1 double}
    {440688×1 double}
    {277104×1 double}
    {130800×1 double}
    {222960×1 double}
    {123888×1 double}
        :        :
        :        :

The display indicates that the number of rows (corresponding to the number of files in the datastore), M, is not yet known. M is a placeholder until the calculation completes.
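
If you do need the actual number of rows, you can force the deferred calculation with gather (a minimal sketch; this may trigger a pass through the data and is usually unnecessary).

M = gather(size(T,1));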

Extract the speech segments from the tall array. This action creates a new tall array variable to use in subsequent calculations. The function HelperSegmentSpeech performs the steps already highlighted in the Isolate Speech Segments section. The cellfun command applies HelperSegmentSpeech to the contents of each audio file in the datastore.

segments = cellfun(@(x)HelperSegmentSpeech(x,Fs),T,"UniformOutput",false);

Extract feature sequences from the speech segments. The function HelperGetFeatureSequences performs the steps already highlighted in the Audio Features section.

FeatureSequences = cellfun(@(x)HelperGetFeatureSequences(x,featureParams,sequenceParams),...
                           segments,"UniformOutput",false);

Use gather to evaluate FeatureSequences:

FeatureSequences = gather(FeatureSequences);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 12 min 2 sec
Evaluation completed in 12 min 2 sec

In classification applications, it is good practice to normalize all features to have zero mean and unit standard deviation.

Compute the mean and standard deviation for each coefficient, and use them to normalize the data.

featuresMatrix = cat(3,FeatureSequences{:});

sequencesMeans = zeros(1,sequenceParams.NumFeatures);
sequenceStds   = zeros(1,sequenceParams.NumFeatures);

for index = 1:sequenceParams.NumFeatures
    localFeatures             = featuresMatrix(:,index,:);
    sequencesMeans(index)     = mean(localFeatures(:));
    sequenceStds(index)       = std(localFeatures(:));
    featuresMatrix(:,index,:) = (localFeatures - sequencesMeans(index))/sequenceStds(index);
end

Create a cell array, features, containing the sequence predictors. Each entry in features is a D-by-S matrix, where D is the number of values per time step (the number of features), and S is the length of the sequence (in this example, 40).

features = cell(1,size(featuresMatrix,3));
for index = 1:size(featuresMatrix,3)
    features{index} = featuresMatrix(:,:,index).';
end

Create a cell array, gender, for the expected gender associated with each training sequence.

numSequences = cellfun(@(x)size(x,3), FeatureSequences);
mylabels     = ads.Labels;
gender       = cell(sum(numSequences),1);
count        = 1;
for index1 = 1:numel(numSequences)
    for index2 = 1:numSequences(index1)
        gender{count} = mylabels{index1};
        count = count + 1;
    end
end

Define the LSTM Network Architecture

LSTM networks can learn long-term dependencies between time steps of sequence data. This example uses the bidirectional LSTM layer bilstmLayer to look at the sequence in both forward and backward directions.

Specify the input size to be sequences of size NumFeatures. Specify a hidden bidirectional LSTM layer with an output size of 100 and output a sequence. This command instructs the bidirectional LSTM layer to map the input time series into 100 features that are passed to the next layer. Then, specify a bidirectional LSTM layer with an output size of 100 and output the last element of the sequence. This command instructs the bidirectional LSTM layer to map its input into 100 features and then prepares the output for the fully connected layer. Finally, specify two classes by including a fully connected layer of size 2, followed by a softmax layer and a classification layer.

layers = [ ...
    sequenceInputLayer(sequenceParams.NumFeatures)
    bilstmLayer(100,"OutputMode","sequence")
    bilstmLayer(100,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

Next, specify the training options for the classifier. Set MaxEpochs to 10 so that the network makes 10 passes through the training data. Set MiniBatchSize to 128 so that the network looks at 128 training signals at a time. Specify Plots as "training-progress" to generate plots that show the training progress as the number of iterations increases. Set Verbose to false to disable printing the table output that corresponds to the data shown in the plot. Specify Shuffle as "every-epoch" to shuffle the training sequences at the beginning of each epoch. Specify LearnRateSchedule as "piecewise" to decrease the learning rate by a specified factor (0.1) every time a certain number of epochs (5) has passed.

This example uses the adaptive moment estimation (ADAM) solver. ADAM performs better with recurrent neural networks (RNNs) like LSTMs than the default stochastic gradient descent with momentum (SGDM) solver.

options = trainingOptions("adam", ...
    "MaxEpochs",10, ...
    "MiniBatchSize",128, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);

Train the LSTM Network

Train the LSTM network with the specified training options and layer architecture using trainNetwork. Because the training set is large, the training process can take several minutes.

net = trainNetwork(features,categorical(gender),layers,options);

The top subplot of the training-progress plot represents the training accuracy, which is the classification accuracy on each mini-batch. When training progresses successfully, this value typically increases towards 100%. The bottom subplot displays the training loss, which is the cross-entropy loss on each mini-batch. When training progresses successfully, this value typically decreases towards zero.

If the training is not converging, the plots might oscillate between values without trending in a certain upward or downward direction. This oscillation means that the training accuracy is not improving and the training loss is not decreasing. This situation can occur at the start of training, or after some preliminary improvement in training accuracy. In many cases, changing the training options can help the network achieve convergence. Decreasing MiniBatchSize or decreasing InitialLearnRate might result in a longer training time, but it can help the network learn better.
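
For reference, here is a minimal sketch of an alternative set of training options with a smaller mini-batch size and a lower initial learning rate. The specific values are hypothetical starting points, not recommendations from the original example.

optionsSlower = trainingOptions("adam", ...
    "MaxEpochs",10, ...
    "MiniBatchSize",64, ...        % smaller mini-batch size
    "InitialLearnRate",1e-4, ...   % lower initial learning rate
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);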

At the end of 10 epochs, the training accuracy exceeds 95%.

Visualize the Training Accuracy

Calculate the training accuracy, which represents the accuracy of the classifier on the signals on which it was trained. First, classify the training data.

trainPred = classify(net,features);
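
You can also compute the overall per-sequence accuracy directly by comparing the predicted and actual labels (a minimal sketch; the confusion chart below shows the same information broken down per class).

% Fraction of training sequences classified correctly.
trainAccuracy = mean(trainPred == categorical(gender));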

Plot the confusion matrix. Display the precision and recall for the two classes by using column and row summaries.

figure;
cm = confusionchart(categorical(gender),trainPred,'title','Training Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

The example generated multiple sequences from each training speech file. Higher accuracy can be achieved by considering the output classes of all sequences corresponding to the same file and applying a "majority-rule" decision, where the class that occurs most frequently is selected.

Predict the gender from each training file by considering the output classes of all sequences generated from the same file:

numFiles        = numel(numSequences);
actualGender    = categorical(ads.Labels);
predictedGender = actualGender;
counter         = 1;
for index = 1:numFiles
    % Get output classes from sequences corresponding to this file:
    predictions = trainPred(counter: counter + numSequences(index) - 1);
    % Set predicted gender to the most frequently predicted class
    predictedGender(index) = mode(predictions);
    counter = counter + numSequences(index);
end

Visualize the confusion matrix on the majority-rule predictions.

figure
cm = confusionchart(actualGender,predictedGender,'title','Training Accuracy - Majority Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Visualize the Testing Accuracy

Measure the classifier's accuracy on testing data containing speech files that were not used when training the LSTM network.

Use audioDatastore to create a datastore for the files in the cv-valid-test folder.

ads = audioDatastore(fullfile(datafolder,"cv-valid-test"));

Use the readtable function to read the metadata associated with these files.

metadata = readtable(fullfile(datafolder,"cv-valid-test.csv"));

Read the gender and age variables from the metadata. Make sure the files in the metadata and the datastore are arranged in the same order.

csvFiles    = metadata.filename;
[~,csvInd]  = sort(csvFiles);

% Read the gender and age variables from the table.
gender     = metadata.gender;
age        = metadata.age;
adsFiles   = ads.Files;
[~,adsInd] = sort(adsFiles);

% Re-arrange gender and age to ensure the information is linked to the correct files.
gender     = gender(csvInd(adsInd));
age        = age(csvInd(adsInd));

Assign gender to the Labels property of the datastore.

ads.Labels = gender;

Create a datastore subset comprising only the files corresponding to adults where gender information is available.

maleOrfemale = categorical(ads.Labels) == "male" | categorical(ads.Labels) == "female";
isAdult      = categorical(age) ~= "" & categorical(age) ~= "teens";
ads          = subset(ads,maleOrfemale & isAdult);

Use countEachLabel to inspect the gender breakdown of the files.

countEachLabel(ads)
ans =

  2×2 table

    Label     Count
    ______    _____

    female     358 
    male      1055 

Remove silence and extract features from the test data. Use gather to evaluate FeatureSequences.

T                = tall(ads);
segments         = cellfun(@(x)HelperSegmentSpeech(x,Fs), T,"UniformOutput",false);
FeatureSequences = cellfun(@(x)HelperGetFeatureSequences(x,featureParams,sequenceParams),...
                           segments,"UniformOutput",false);

FeatureSequences = gather(FeatureSequences);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 12).
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 3 min 2 sec
Evaluation completed in 3 min 2 sec

Normalize the feature sequences using the means and standard deviations computed during the training stage.

featuresMatrix = cat(3,FeatureSequences{:});
for index=1:sequenceParams.NumFeatures
    XI                        = featuresMatrix(:,index,:);
    featuresMatrix(:,index,:) = (XI - sequencesMeans(index))/sequenceStds(index);
end

Create a cell array containing the sequence predictors.

features = cell(1,size(featuresMatrix,3));
for index = 1:size(featuresMatrix,3)
    features{index} = featuresMatrix(:,:,index).';
end

% Create a cell array, gender, for the expected gender associated with
% each test sequence.
numSequences = cellfun(@(x)size(x,3),FeatureSequences);
mylabels     = ads.Labels;
gender       = cell(sum(numSequences),1);
count        = 1;
for index1 =1:numel(numSequences)
    for index2 = 1:numSequences(index1)
        gender{count} = mylabels{index1};
        count = count + 1;
    end
end

Predict gender for each sequence.

testPred = classify(net,features);

Plot the confusion matrix.

figure
cm = confusionchart(categorical(gender),testPred,'title','Testing Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Predict the gender from each test file by considering the output classes of all sequences generated from the same file:

numFiles        = numel(numSequences);
actualGender    = categorical(ads.Labels);
predictedGender = actualGender;
counter         = 1;
for index = 1:numFiles
    % Get output classes from sequences corresponding to this file:
    predictions = testPred(counter: counter + numSequences(index) - 1);

    % Set predicted gender to the most frequently predicted class
    predictedGender(index) = mode(predictions);

    counter = counter + numSequences(index);
end

Visualize the confusion matrix on the majority-rule predictions.

figure
cm = confusionchart(actualGender,predictedGender,'title','Testing Accuracy - Majority Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

References

[1] Mozilla Common Voice Dataset. https://www.kaggle.com/mozillaorg/common-voice

[2] Giannakopoulos, Theodoros, and Aggelos Pikrakis. Introduction to Audio Analysis: A MATLAB Approach. Academic Press, 2014.