
Train 3-D Sound Event Localization and Detection (SELD) Using Deep Learning

In this example, you train a deep learning model to perform sound localization and event detection from ambisonic data. The model consists of two independently trained convolutional recurrent neural networks (CRNNs) [1]: one for sound event detection (SED), and one for direction of arrival (DOA) estimation. To explore the models trained in this example, see 3-D Sound Event Localization and Detection Using Trained Recurrent Convolutional Neural Network.

Introduction

Ambisonics is a popular 3-D sound format that has shown promise in tasks such as sound source localization, speech enhancement, and source separation. Ambisonics is a full-sphere surround sound format that contains a speaker-independent sound field representation (B-format). First-order B-format ambisonic recordings contain four components: the sound pressure captured by an omnidirectional microphone (W), and the sound pressure gradients X, Y, and Z, captured by figure-of-eight capsules oriented along the front/back, left/right, and up/down axes, respectively. 3-D SELD has applications in virtual reality, robotics, smart homes, and defense.
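To make the four components concrete, the following sketch encodes a mono signal arriving from one direction into first-order B-format. It is an illustration under stated assumptions: the SN3D-style normalization (W equal to the source signal), the azimuth/elevation convention, and the signal are all hypothetical, and the normalization used by L3DAS21 may differ (FuMa, for example, scales W by 1/sqrt(2)).

```matlab
% Hypothetical source signal: 0.1 s of white noise at 32 kHz.
fs = 32e3;
s = randn(round(0.1*fs),1);

% Hypothetical direction in degrees. Azimuth is measured counterclockwise
% from the front (x-axis), elevation upward from the horizontal plane.
az = 90;
el = 0;

% First-order encoding, assuming an SN3D-style normalization: W is the
% omnidirectional pressure and X, Y, Z are figure-of-eight components
% weighted by the direction cosines of the source.
W = s;
X = s*cosd(az)*cosd(el);   % front/back
Y = s*sind(az)*cosd(el);   % left/right
Z = s*sind(el);            % up/down

% One column per component, in W, X, Y, Z order.
bformat = [W,X,Y,Z];
```

With az = 90 and el = 0, the source lies on the left/right axis under this convention, so X and Z are zero and Y carries the full signal. For a single plane wave, the ratios of X, Y, and Z to W encode the source direction.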

You train two separate models: one for the sound event detection task and one for the localization task. Both models are based on the convolutional recurrent neural network architecture described in [1]. The sound event detection task is formulated as a classification task. The sound event localization task estimates the Cartesian coordinates of the sound source and is formulated as a regression task. You use the L3DAS21 data set [2] to train and validate the networks.

Download and Prepare Data

This example uses a subset of the L3DAS21 Task 2 challenge data set [2]. The data set contains multiple-source and multiple-perspective (MSMP) B-format ambisonic audio recordings sampled at 32 kHz. The train and validation splits are provided with the data set. Each recording is one minute long and contains a simulated 3-D audio environment in which up to three acoustic events can be active simultaneously. In this example, you use only the data that contains non-overlapping sounds. The sound events belong to 14 sound classes. The labels are provided as CSV files that contain the sound class, the Cartesian coordinates of the sound source, and the onset and offset time stamps.

Download the data set.

downloadFolder = matlab.internal.examples.downloadSupportFile("audio","L3DAS21_ov1.zip");
dataFolder = tempdir;
unzip(downloadFolder,dataFolder)
dataset = fullfile(dataFolder,"L3DAS21_ov1");

Optionally Reduce Data Set

To train the networks with the entire data set and achieve reasonable performance, set speedupExample to false. To run this example quickly, set speedupExample to true. This setting reduces the training set to two recordings and trains each network for a single epoch.

speedupExample = false;

Create Datastores

Create audioDatastore objects to ingest the data. Each data point in the data set consists of two B-format ambisonic recordings that correspond to the two microphones (A and B). For each data folder (train and validation), use subset to create two subsets corresponding to the two microphones.

adsTrain = audioDatastore(fullfile(dataset,"train","data"));
adsTrainA = subset(adsTrain,cellfun(@(c)endsWith(c,"A.wav"),adsTrain.Files));
adsTrainB = subset(adsTrain,cellfun(@(c)endsWith(c,"B.wav"),adsTrain.Files));

adsValidation = audioDatastore(fullfile(dataset,"validation","data"));
adsValidationA = subset(adsValidation,cellfun(@(c)endsWith(c,"A.wav"),adsValidation.Files));
adsValidationB = subset(adsValidation,cellfun(@(c)endsWith(c,"B.wav"),adsValidation.Files));

Reduce the data set if requested.

if speedupExample
    adsTrainA = subset(adsTrainA,1:2);
    adsTrainB = subset(adsTrainB,1:2);
end

Inspect Data

Preview the ambisonic recordings and plot the data.

micA = preview(adsTrainA);
micB = preview(adsTrainB);

tiledlayout(4,2,TileSpacing="tight")

nexttile
plot(micA(:,1))
title("Microphone A")
ylabel("W")

nexttile
plot(micB(:,1))
title("Microphone B")

nexttile
plot(micA(:,2))
ylabel("X")

nexttile
plot(micB(:,2))

nexttile
plot(micA(:,3))
ylabel("Y")

nexttile
plot(micB(:,3))

nexttile
plot(micA(:,4))
ylabel("Z")

nexttile
plot(micB(:,4))

Listen to a section of the data.

microphone = 1;
channel = 1;
duration = 10;
fs = 32e3; % Known sampling rate of data.

s = [micA,micB];
data = s(1:round(duration*fs),channel + (microphone-1)*4);
sound(data,fs)
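The column index channel + (microphone-1)*4 assumes that s = [micA,micB] holds four channels per microphone in W, X, Y, Z order, so that columns 1 to 4 belong to microphone A and columns 5 to 8 to microphone B. A short sketch that labels each column under that assumption (the label strings are hypothetical):

```matlab
% Label each column of s = [micA,micB], assuming four channels per
% microphone in W, X, Y, Z order.
channelNames = ["W","X","Y","Z"];
micNames = ["A","B"];
columnLabels = strings(1,8);
for mic = 1:2
    for ch = 1:4
        col = ch + (mic-1)*4;   % same formula as channel + (microphone-1)*4
        columnLabels(col) = micNames(mic) + "-" + channelNames(ch);
    end
end
disp(columnLabels)
```

Under the same layout, channel = 7 in the STFT plot later in this example corresponds to the Y component of microphone B.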

Create Targets

Each data point in the data set has a corresponding CSV file containing the sound event class, the start and end times of the sound, and the location of the sound. Create a containers.Map object that maps each sound class name to an integer class index.

keySet = ["Chink_and_clink","Computer_keyboard","Cupboard_open_or_close","Drawer_open_or_close", ...
    "Female_speech_and_woman_speaking","Finger_snapping","Keys_jangling","Knock","Laughter", ...
    "Male_speech_and_man_speaking","Printer","Scissors","Telephone","Writing"];
valueSet = {1,2,3,4,5,6,7,8,9,10,11,12,13,14};
params.SoundClasses = containers.Map(keySet,valueSet);

Create a tabularTextDatastore to ingest the train file labels. Make sure the label files are in the same order as the data files. Preview a label file from the datastore.

[folder,fn] = fileparts(adsTrainA.Files);
targetPath = fullfile(strrep(folder,filesep+"data",filesep+"labels"),"label_" + strrep(fn,"_A","") + ".csv");
ttdsTrain = tabularTextDatastore(targetPath);

labelTable = preview(ttdsTrain)
labelTable=8×7 table
    File     Start      End                     Class                     X       Y       Z  
    ____    _______    ______    ____________________________________    ____    ____    ____

     0      0.54784    9.6651    {'Writing'                         }     0.5    -1.5     0.3
     0       11.521    12.534    {'Finger_snapping'                 }    0.75    1.25      -1
     0       14.255    16.064    {'Keys_jangling'                   }     0.5    -1.5     0.3
     0       17.728    18.878    {'Chink_and_clink'                 }     0.5       1       0
     0        19.95      20.4    {'Printer'                         }    -1.5    -1.5    -0.6
     0       20.994    23.477    {'Cupboard_open_or_close'          }    -0.5    0.75       0
     0       25.032    25.723    {'Chink_and_clink'                 }      -2    -0.5    -0.3
     0       26.547    27.491    {'Female_speech_and_woman_speaking'}       1    -1.5       0
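The label path for each data file is derived from the data file name. The sketch below applies the same string operations to one hypothetical file name (split0_ov1_3_A.wav is an assumed example; actual L3DAS21 file names may differ) to show the mapping from the data folder to the labels folder.

```matlab
% Hypothetical data file name; actual L3DAS21 file names may differ.
dataFile = fullfile("L3DAS21_ov1","train","data","split0_ov1_3_A.wav");

% Same steps as above: swap the "data" folder for "labels", drop the
% "_A" microphone suffix, and add the "label_" prefix and .csv extension.
[exFolder,exName] = fileparts(dataFile);
labelFolder = strrep(exFolder,filesep + "data",filesep + "labels");
labelFile = fullfile(labelFolder,"label_" + strrep(exName,"_A","") + ".csv");
disp(labelFile)
```

Because strrep replaces every occurrence, any other folder in the path whose name starts with data (for example, a hypothetical data_disk mount point) would also be rewritten. Checking that the generated label files exist before creating the datastore is a cheap safeguard.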

The labels in the data set provide time stamps in seconds. To create targets and train a network, you need to map the time stamps to frames. Each file is 60 seconds long. Divide each file into 600 target frames, so that the model makes one prediction every 0.1 seconds.

params.Targets.TotalDuration = 60;
params.Targets.NumFrames = 600;
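As a worked example of the frame grid, the sketch below applies the same arithmetic as the time2frame supporting function to the first event in the preview table (Start = 0.54784 s, End = 9.6651 s). The times come from the table above; the variable names are illustrative.

```matlab
totalDuration = 60;     % params.Targets.TotalDuration (s)
numFrames = 600;        % params.Targets.NumFrames
frameDuration = totalDuration/numFrames;   % 0.1 s per target frame

% Start and end times (s) of the first event in the preview table.
tEvent = [0.54784,9.6651];

% Same steps as time2frame: snap to the frame grid, then convert to a
% 1-based frame number.
qt = round(tEvent./frameDuration).*frameDuration;
fnum = floor(qt*(numFrames - 1)/totalDuration) + 1;

% The target functions mark frames fnum(1) through fnum(2)-1 as active.
activeFrames = fnum(1):fnum(2)-1;
```

The event spans frames 5 through 96. Because the formula scales by numFrames - 1 rather than numFrames, a boundary can land one frame earlier than a plain floor(t/frameDuration) + 1 mapping would give (frame 5 instead of 6 for 0.5 s).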

SED Targets

The supporting function, extractSEDTargets, uses the label data to create an SED target. The target is a one-hot encoded matrix of size numframes-by-numclasses. Frames with no sounds present are encoded as all-zero vectors.

SEDTargets = extractSEDTargets(labelTable,params);

[numframes,numclasses] = size(SEDTargets{1})
numframes = 600
numclasses = 14
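To illustrate the target layout, the following toy sketch builds the same kind of one-hot matrix as extractSEDTargets for 10 frames, 3 classes, and two non-overlapping events. The frame indices and class IDs are hypothetical; as in the supporting function, the end frame is exclusive.

```matlab
numFramesToy = 10;
numClassesToy = 3;
classID = [2;3];          % class index of each event
startFrame = [2;6];       % first active frame of each event
endFrame = [5;9];         % exclusive end frame, as in extractSEDTargets

toyTarget = zeros(numFramesToy,numClassesToy);
for ii = 1:numel(classID)
    % Set the column of the event's class to 1 for each active frame.
    toyTarget(startFrame(ii):endFrame(ii)-1,classID(ii)) = 1;
end
disp(toyTarget)
```

Frames 1, 5, 9, and 10 stay all-zero, which is how silence is encoded.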

Extract SED targets from the train and validation sets.

dsTTrain = transform(ttdsTrain,@(x)extractSEDTargets(x,params));
sedTTrain = readall(dsTTrain);

[folder,fn] = fileparts(adsValidationA.Files);
targetPath = fullfile(strrep(folder,filesep+"data",filesep+"labels"),"label_" + strrep(fn,"_A","") + ".csv");

ttdsValidation = tabularTextDatastore(targetPath);
dsTValidation = transform(ttdsValidation,@(x)extractSEDTargets(x,params));
sedTValidation = readall(dsTValidation);

DOA Targets

The supporting function, extractDOATargets, uses the label data to create a DOA target. The target is a matrix of size numframes-by-numaxis. The axis values correspond to the sound source location in 3-D space. Frames with no sounds present are encoded as all-zero vectors.

First, define a parameter to scale the target axis values so that they are between -1 and 1. This scaling is necessary because the DOA network you define later uses tanh activation as its final layer.

params.DOA.ScaleFactor = 2;
DOATargets = extractDOATargets(labelTable,params);

[numframes,numaxis] = size(DOATargets{1})
numframes = 600
numaxis = 3
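A scale factor of 2 keeps the targets inside the tanh range only if every coordinate magnitude is at most 2. The sketch below checks that for the eight events in the preview table above and shows how to undo the scaling on a network output. It assumes the preview is representative; checking all label files is more reliable.

```matlab
% X, Y, Z coordinates of the events in the label table preview.
xyz = [0.5 -1.5 0.3; 0.75 1.25 -1; 0.5 -1.5 0.3; 0.5 1 0; ...
       -1.5 -1.5 -0.6; -0.5 0.75 0; -2 -0.5 -0.3; 1 -1.5 0];
scaleFactor = 2;                 % params.DOA.ScaleFactor
scaled = xyz/scaleFactor;

% tanh outputs lie in (-1,1), so every scaled target must stay in [-1,1].
inRange = all(abs(scaled) <= 1,"all");

% Multiply by the scale factor to turn network outputs back into coordinates.
recovered = scaled*scaleFactor;
```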

Extract DOA targets from the train and validation sets.

dsTTrain = transform(ttdsTrain,@(x)extractDOATargets(x,params));
doaTTrain = readall(dsTTrain);

[folder,fn] = fileparts(adsValidationA.Files);
targetPath = fullfile(strrep(folder,filesep+"data",filesep+"labels"),"label_" + strrep(fn,"_A","") + ".csv");

ttdsValidation = tabularTextDatastore(targetPath);
dsTValidation = transform(ttdsValidation,@(x)extractDOATargets(x,params));
doaTValidation = readall(dsTValidation);

Sound Event Detection (SED)

Feature Extraction

The sound event detection model uses log-magnitude short-time Fourier transforms (STFT) as predictors to the system. Specify a 512-point periodic Hamming window and a hop length of 400 samples.

params.SED.SampleRate = 32e3;
params.SED.HopLength = 400;
params.SED.Window = hamming(512,"periodic");

The supporting function, extractSTFT, takes a cell array of microphone recordings and extracts the one-sided, centered log-magnitude STFTs. The STFTs of all eight channels (four per microphone) are stacked along the third dimension.

stftFeats = extractSTFT({micA,micB},params);
[numfeaturesSED,numframesSED,numchannelsSED] = size(stftFeats)
numfeaturesSED = 256
numframesSED = 4800
numchannelsSED = 8
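The feature dimensions above follow from the window and hop settings. The sketch below reproduces that arithmetic, assuming a one-minute recording at 32 kHz and the padding and trimming steps in the extractSTFT and centeredSTFT supporting functions.

```matlab
fs = 32e3;                  % params.SED.SampleRate
numSamples = 60*fs;         % one-minute recording
fftLength = 512;            % numel(params.SED.Window)
hopLength = 400;            % params.SED.HopLength

% centeredSTFT pads fftLength/2 reflected samples at each end.
paddedLength = numSamples + fftLength;

% A one-sided spectrum has fftLength/2 + 1 bins; extractSTFT drops the DC bin.
numBins = (fftLength/2 + 1) - 1;

% Number of full windows that fit in the padded signal, minus the last
% spectrum that extractSTFT trims.
numFramesSTFT = (floor((paddedLength - fftLength)/hopLength) + 1) - 1;

% Number of STFT frames per 0.1 s target frame (600 target frames per file).
stftFramesPerTarget = numFramesSTFT/600;
```

Because each 0.1 s target frame spans eight STFT frames, the CNN must reduce the time axis by a factor of 8 before the recurrent layers.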

Plot the STFT features of one channel.

channel = 7;

figure
imagesc(stftFeats(:,:,channel))
colorbar
xlabel("Frame")
ylabel("Frequency (bin)")
set(gca,YDir="normal")

Extract features from the entire train and validation sets. First, combine the datastores corresponding to microphones A and B. Then, define a transform on the datastore so that reading from it returns the STFT. If you have Parallel Computing Toolbox™, you can speed up processing by using the UseParallel name-value argument of readall.

pFlag = ~isempty(ver("parallel")) && ~speedupExample;

trainDS = combine(adsTrainA,adsTrainB);
trainDS_T = transform(trainDS,@(x){extractSTFT(x,params)},IncludeInfo=false);
XTrain = readall(trainDS_T,UseParallel=pFlag);
valDS = combine(adsValidationA,adsValidationB);
valDS_T = transform(valDS,@(x){extractSTFT(x,params)},IncludeInfo=false);
XValidation = readall(valDS_T,UseParallel=pFlag);

Combine the predictor arrays with the previously computed SED target arrays.

trainSedDS = combine(arrayDatastore(XTrain,OutputType="same"),arrayDatastore(sedTTrain,OutputType="same"));
valSedDS = combine(arrayDatastore(XValidation,OutputType="same"),arrayDatastore(sedTValidation,OutputType="same"));

Training Options

Define training parameters for Adam optimization.

trainOptionsSED = struct( ...
    MaxEpochs=300, ...
    MiniBatchSize=4, ...
    InitialLearnRate=1e-5, ...
    GradientDecayFactor=0.01, ...
    SquaredGradientDecayFactor=0.0, ...
    ValidationPatience=25, ...
    LearnRateDropPeriod=100, ...
    LearnRateDropFactor=1);

if speedupExample
    trainOptionsSED.MaxEpochs = 1;
end
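The Adam update applied by adamupdate can be sketched as follows for the hyperparameters above. This is the textbook formulation with bias correction; the epsilon value of 1e-8 is assumed to match the adamupdate default, and the parameter and gradient values are hypothetical.

```matlab
alpha = 1e-5;        % InitialLearnRate
beta1 = 0.01;        % GradientDecayFactor
beta2 = 0.0;         % SquaredGradientDecayFactor
epsilon = 1e-8;      % assumed to match the adamupdate default

theta = [0.5; -0.3];     % hypothetical parameters
g = [2; -0.004];         % hypothetical gradient
m = zeros(size(theta));  % first-moment estimate (averageGrad)
v = zeros(size(theta));  % second-moment estimate (averageSqGrad)
t = 1;                   % iteration number

% Update the moment estimates.
m = beta1*m + (1 - beta1)*g;
v = beta2*v + (1 - beta2)*g.^2;

% Bias-corrected estimates and parameter step.
mHat = m/(1 - beta1^t);
vHat = v/(1 - beta2^t);
thetaNew = theta - alpha*mHat./(sqrt(vHat) + epsilon);
```

With SquaredGradientDecayFactor set to 0, the second-moment estimate is just the current squared gradient, so each step has a magnitude close to InitialLearnRate regardless of the gradient size, much like sign-based gradient descent.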

Create minibatchqueue (Deep Learning Toolbox) objects to read mini-batches from the train and validation datastores.

trainSEDmbq = minibatchqueue(trainSedDS, ...
    MiniBatchSize=trainOptionsSED.MiniBatchSize, ...
    OutputAsDlarray=[1,1], ...
    MiniBatchFormat=["SSCB","TCB"], ...
    OutputEnvironment=["auto","auto"]);

validationSEDmbq = minibatchqueue(valSedDS, ...
    MiniBatchSize=trainOptionsSED.MiniBatchSize, ...
    OutputAsDlarray=[1,1], ...
    MiniBatchFormat=["SSCB","TCB"], ...
    OutputEnvironment=["auto","auto"]);

Define Sound Event Detection (SED) Network

The network is implemented in two stages: a convolutional neural network (CNN) followed by a gated recurrent unit (GRU) network. A custom reshaping layer recasts the output of the CNN into a sequence, which is passed as the input to the recurrent network. The custom reshaping layer is placed in your current folder when you open this example. The final output layer uses sigmoid activation.

Define the CNN layers for the SED model.

seldnetCNNLayers = [
    imageInputLayer([numfeaturesSED,numframesSED,numchannelsSED],Normalization="none",Name="input")

    convolution2dLayer([3,3],64,Padding="same",Name="conv1")
    batchNormalizationLayer(Name="batchnorm1")
    reluLayer(Name="relu1")
    maxPooling2dLayer([8,2],Stride=[8,2],Padding="same",Name="maxpool1")

    convolution2dLayer([3,3],128,Padding="same",Name="conv2")
    batchNormalizationLayer(Name="batchnorm2")
    reluLayer(Name="relu2")
    maxPooling2dLayer([8,2],Stride=[8,2],Padding="same",Name="maxpool2")

    convolution2dLayer([3,3],256,Padding="same",Name="conv3")
    batchNormalizationLayer(Name="batchnorm3")
    reluLayer(Name="relu3")
    maxPooling2dLayer([2,2],Stride=[2,2],Padding="same",Name="maxpool3")

    convolution2dLayer([3,3],512,Padding="same",Name="conv4")
    batchNormalizationLayer(Name="batchnorm4")
    reluLayer(Name="relu4")
    maxPooling2dLayer([1,1],Stride=[1,1],Padding="same",Name="maxpool4")

    reshapeLayer("reshape")
    ];
netCNN = dlnetwork(seldnetCNNLayers);
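To see why the sequence input layer of the recurrent network expects 1024 features and why the output has 600 time steps, track the feature map size through the pooling layers. This sketch assumes that the stride-1 convolution layers with Padding="same" preserve the spatial size, that pooling with Padding="same" returns ceil(inputSize./stride), and that the custom reshape layer folds the remaining frequency rows and the 512 channels into one feature dimension.

```matlab
inputSize = [256 4800];             % [frequency time] STFT size per channel
numFiltersLast = 512;               % filters in conv4
poolSizes = [8 2; 8 2; 2 2; 1 1];   % maxpool1 to maxpool4, stride = pool size

outSize = inputSize;
for k = 1:size(poolSizes,1)
    % Pooling with Padding="same" returns ceil(inputSize./stride).
    outSize = ceil(outSize./poolSizes(k,:));
end

% Assumed reshape: fold remaining frequency rows and channels into features.
numSequenceFeatures = outSize(1)*numFiltersLast;
numSequenceSteps = outSize(2);
```

The 600 time steps match params.Targets.NumFrames, so each RNN output step aligns with one 0.1 s target frame.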

Define the RNN layers for the SED model.

seldnetGRULayers = [
    sequenceInputLayer(1024,Name="sequenceInputLayer")

    bigruLayer(1024,256,Name="gru1")
    bigruLayer(512,256,Name="gru2")
    bigruLayer(512,256,Name="gru3")

    fullyConnectedLayer(1024,Name="fc1")
    reluLayer(Name="relu1")
    fullyConnectedLayer(1024,Name="fc2")
    reluLayer(Name="relu2")
    fullyConnectedLayer(1024,Name="fc3")
    reluLayer(Name="relu3")

    fullyConnectedLayer(params.SoundClasses.Count,Name="fc4")
    sigmoidLayer(Name="output")
    ];

netRNN = dlnetwork(seldnetGRULayers);

Create a struct to contain both the CNN and RNN sections of the full model.

sedModel.CNN = netCNN;
sedModel.RNN = netRNN;

Train SED Network

Initialize variables to track the progress of the training.

iteration = 0;
averageGrad = [];
averageSqGrad = [];
epoch = 0;
bestLoss = Inf;
badEpochs = 0;
learnRate = trainOptionsSED.InitialLearnRate;

To display training progress, create a progressPlotterSELD object. This supporting object is placed in your current folder when you open this example.

pp = progressPlotterSELD();

Run the training loop.

rng(0)
while epoch < trainOptionsSED.MaxEpochs && badEpochs < trainOptionsSED.ValidationPatience
    
    epoch = epoch + 1;

    % Shuffle mini-batch queue.
    shuffle(trainSEDmbq)

    while hasdata(trainSEDmbq)

        % Update iteration counter.
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(trainSEDmbq);

        % Evaluate the model gradients and loss using dlfeval and the modelLoss function.
        [loss,grad,state] = dlfeval(@modelLoss,sedModel,X,T);
        loss = loss/size(T,2);

        % Update state.
        sedModel.CNN.State = state.CNN;
        sedModel.RNN.State = state.RNN;

        % Update the network parameters using the Adam optimizer.
        [sedModel,averageGrad,averageSqGrad] = adamupdate(sedModel,grad,averageGrad, ...
            averageSqGrad,iteration,learnRate,trainOptionsSED.GradientDecayFactor,trainOptionsSED.SquaredGradientDecayFactor);

        % Update the training progress plot.
        updateTrainingProgress(pp,Epoch=epoch,LearnRate=learnRate,Iteration=iteration,Loss=loss);
    end

    % Perform validation after each epoch.
    loss = predictBatch(sedModel,validationSEDmbq);

    % Update the training progress plot with validation results.
    updateValidation(pp,Loss=loss,Iteration=iteration)

    % Create a checkpoint if the validation loss improved. If validation
    % loss did not improve, add to the number of bad epochs.
    if loss < bestLoss
        bestLoss = loss;
        badEpochs = 0;
        fileName = "SED-BestModel";
        save(fileName,"sedModel");
    else
        badEpochs = badEpochs + 1;
    end

    % Update learn rate
    if rem(epoch,trainOptionsSED.LearnRateDropPeriod)==0
        learnRate = learnRate*trainOptionsSED.LearnRateDropFactor;
    end

end
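The control flow of the epoch loop (early stopping with a patience counter and a step learning-rate schedule) can be checked in isolation. The sketch below runs it over a hypothetical sequence of validation losses, with a smaller patience and an actual learning-rate drop so that both mechanisms trigger. The example itself uses ValidationPatience=25 and LearnRateDropFactor=1, which keeps the learning rate constant.

```matlab
% Hypothetical validation losses, one per epoch.
valLoss = [0.90 0.80 0.85 0.79 0.81 0.82 0.83 0.70];
patience = 3;        % ValidationPatience (25 in the example)
dropPeriod = 2;      % LearnRateDropPeriod (100 in the example)
dropFactor = 0.5;    % LearnRateDropFactor (1 in the example: no drop)
lr = 1e-5;           % InitialLearnRate

best = Inf;          % plays the role of bestLoss
bad = 0;             % plays the role of badEpochs
ep = 0;              % plays the role of epoch
while ep < numel(valLoss) && bad < patience
    ep = ep + 1;
    if valLoss(ep) < best
        best = valLoss(ep);   % the example saves a checkpoint here
        bad = 0;
    else
        bad = bad + 1;
    end
    if rem(ep,dropPeriod) == 0
        lr = lr*dropFactor;
    end
end
fprintf("Stopped after epoch %d, best loss %.2f, learn rate %g\n",ep,best,lr)
```

Training stops after epoch 7 because three consecutive epochs fail to improve on the best loss of 0.79, so the lower loss at epoch 8 is never reached.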

Direction of Arrival (DOA)

Feature Extraction

The direction of arrival estimation model uses generalized cross-correlation with phase transform (GCC-PHAT) features as predictors. Specify a 1024-point Hann window, a hop length of 400 samples, and 96 bands.

params.DOA.SampleRate = 32e3;
params.DOA.Window = hann(1024);
params.DOA.NumBands = 96;
params.DOA.HopLength = 400;

Extract the GCC-PHAT features used as input predictors to the sound localization network. The GCC-PHAT algorithm measures the cross-correlation between each pair of channels. The input signals have a total of 8 channels (4 per microphone), so there are nchoosek(8,2) = 28 channel pairs and therefore 28 output channels.

gccPhatFeats = extractGCCPHAT({micA,micB},params);
[numfeaturesDOA,timestepsDOA,numchannelsDOA] = size(gccPhatFeats)
numfeaturesDOA = 96
timestepsDOA = 4800
numchannelsDOA = 28
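For a single pair of channels, GCC-PHAT whitens the cross-power spectrum so that only phase (that is, time-delay) information remains, and the peak of its inverse FFT indicates the relative delay. The sketch below demonstrates this on a hypothetical noise signal and a delayed copy, using the same phase-only weighting as extractGCCPHAT. Unlike the supporting function, it processes a single unwindowed frame and keeps all lags rather than only the central 96.

```matlab
rng(1)                      % reproducible hypothetical signals
N = 1024;                   % frame length, as in params.DOA.Window
trueDelay = 5;              % hypothetical delay in samples
x1 = randn(N,1);
x2 = [zeros(trueDelay,1); x1(1:end-trueDelay)];   % x2 is x1 delayed

% Cross-power spectrum, then PHAT weighting: keep only the phase,
% as extractGCCPHAT does with exp(1i.*angle(R)).
R = fft(x2).*conj(fft(x1));
R = exp(1i.*angle(R));

% Back to the lag domain; fftshift moves lag 0 to the middle.
gcc = fftshift(ifft(R,[],1,"symmetric"),1);
lags = (-N/2:N/2-1).';
[~,k] = max(gcc);
estimatedDelay = lags(k);

% Eight channels (4 per microphone) give nchoosek(8,2) channel pairs.
numPairs = nchoosek(8,2);
```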

Plot the GCC-PHAT features of a channel pair.

channelpair = 1;

figure
imagesc(gccPhatFeats(:,:,channelpair))
colorbar
xlabel("Frame")
ylabel("Band")
set(gca,YDir="normal")

Extract features from the entire train and validation sets. If you have Parallel Computing Toolbox™, you can speed up processing by using the UseParallel name-value argument of readall.

pFlag = ~isempty(ver("parallel")) && ~speedupExample;

trainDS = combine(adsTrainA,adsTrainB);
trainDS_T = transform(trainDS,@(x){extractGCCPHAT(x,params)},IncludeInfo=false);
XTrain = readall(trainDS_T,UseParallel=pFlag);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
valDS = combine(adsValidationA,adsValidationB);
valDS_T = transform(valDS,@(x){extractGCCPHAT(x,params)},IncludeInfo=false);
XValidation = readall(valDS_T,UseParallel=pFlag);

Combine the predictor arrays with the previously computed DOA target arrays.

trainDOA = combine(arrayDatastore(XTrain,OutputType="same"),arrayDatastore(doaTTrain,OutputType="same"));
validationDOA = combine(arrayDatastore(XValidation,OutputType="same"),arrayDatastore(doaTValidation,OutputType="same"));

Training Options

Use the same training options that you defined for the SED network.

trainOptionsDOA = trainOptionsSED;

Create mini-batch queues for the train and validation sets.

trainDOAmbq = minibatchqueue(trainDOA, ...
    MiniBatchSize=trainOptionsDOA.MiniBatchSize, ...
    OutputAsDlarray=[1,1], ...
    MiniBatchFormat=["SSCB","TCB"], ...
    OutputEnvironment=["auto","auto"]);
validationDOAmbq = minibatchqueue(validationDOA, ...
    MiniBatchSize=trainOptionsDOA.MiniBatchSize, ...
    OutputAsDlarray=[1,1], ...
    MiniBatchFormat=["SSCB","TCB"], ...
    OutputEnvironment=["auto","auto"]);

Define Direction of Arrival (DOA) Network

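The DOA network is very similar to the SED network defined earlier. The key differences are the size of the input layer, the pooling size of the first max pooling layer, and the output layers: the DOA network ends with a three-unit fully connected layer (one unit per axis) followed by a tanh activation instead of a sigmoid activation.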

Update the SELDnet architecture used for the SED network for use with DOA estimation.

seldnetCNNLayers(1) = imageInputLayer([numfeaturesDOA,timestepsDOA,numchannelsDOA],Normalization="none",Name="input");
seldnetCNNLayers(5) = maxPooling2dLayer([3,2],Stride=[3,2],Padding="same",Name="maxpool1");
netCNN = dlnetwork(seldnetCNNLayers);

seldnetGRULayers(11) = fullyConnectedLayer(3,Name="fc4");
seldnetGRULayers(12) = tanhLayer(Name="output");
netRNN = dlnetwork(seldnetGRULayers);

Create a struct to contain both the CNN and RNN sections of the full model.

doaModel.CNN = netCNN;
doaModel.RNN = netRNN;
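The [3,2] pooling in the first DOA pooling layer compensates for the smaller input: 96 GCC-PHAT lags divided by 3 give the same 32 rows that 256 STFT bins divided by 8 give in the SED network, so the rest of the architecture, including the 1024-feature sequence input, can stay unchanged. A sketch of the arithmetic, assuming stride-1 convolutions that preserve the size, pooling with Padding="same" that returns ceil(inputSize./stride), and a reshape layer that folds the remaining rows and the 512 channels into features:

```matlab
inputSize = [96 4800];              % [lags time] GCC-PHAT size per channel pair
poolSizes = [3 2; 8 2; 2 2; 1 1];   % maxpool1 now uses [3,2]; the rest are unchanged

outSize = inputSize;
for k = 1:size(poolSizes,1)
    outSize = ceil(outSize./poolSizes(k,:));   % pooling with Padding="same"
end

% Assumed reshape: remaining rows times the 512 filters of conv4.
numSequenceFeatures = outSize(1)*512;
```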

Train DOA Network

Initialize variables used in the training loop.

iteration = 0;
averageGrad = [];
averageSqGrad = [];
epoch = 0;
bestLoss = Inf;
badEpochs = 0;
learnRate = trainOptionsDOA.InitialLearnRate;

To display training progress, create a progressPlotterSELD object. This supporting object is placed in your current folder when you open this example.

pp = progressPlotterSELD();

Run the training loop.

rng(0)
while epoch < trainOptionsDOA.MaxEpochs && badEpochs < trainOptionsDOA.ValidationPatience
    
    epoch = epoch + 1;

    % Shuffle mini-batch queue.
    shuffle(trainDOAmbq)

    while hasdata(trainDOAmbq)

        % Update iteration counter.
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(trainDOAmbq);

        % Evaluate the model gradients and loss using dlfeval and the modelLoss function.
        [loss,grad,state] = dlfeval(@modelLoss,doaModel,X,T);
        loss = loss/size(T,2);

        % Update state.
        doaModel.CNN.State = state.CNN;
        doaModel.RNN.State = state.RNN;

        % Update the network parameters using the Adam optimizer.
        [doaModel,averageGrad,averageSqGrad] = adamupdate(doaModel,grad,averageGrad, ...
            averageSqGrad,iteration,learnRate,trainOptionsDOA.GradientDecayFactor,trainOptionsDOA.SquaredGradientDecayFactor);

        % Update the training progress plot
        updateTrainingProgress(pp,Epoch=epoch,LearnRate=learnRate,Iteration=iteration,Loss=loss);
    end

    % Perform validation after each epoch
    loss = predictBatch(doaModel,validationDOAmbq);

    % Update the training progress plot with validation results.
    updateValidation(pp,Loss=loss,Iteration=iteration)

    % Create a checkpoint if the validation loss improved. If validation
    % loss did not improve, add to the number of bad epochs.
    if loss < bestLoss
        bestLoss = loss;
        badEpochs = 0;
        fileName = "DOA-BestModel";
        save(fileName,"doaModel");
    else
        badEpochs = badEpochs + 1;
    end

    % Update learn rate
    if rem(epoch,trainOptionsDOA.LearnRateDropPeriod)==0
        learnRate = learnRate*trainOptionsDOA.LearnRateDropFactor;
    end
end

Evaluate System Performance

To evaluate the performance of the system, use the location-sensitive detection metrics defined in [4]. Load the best-performing models.

sedModel = importdata("SED-BestModel.mat");
doaModel = importdata("DOA-BestModel.mat");

Location-sensitive detection is a joint metric that evaluates the results of both the sound event detection and the sound event localization tasks. In this type of evaluation, a prediction counts as a true positive only when the predicted class is correct and the predicted location is within a predefined distance threshold of the true location. This example uses a threshold of 0.2, which is about 3% of the maximum possible error. To determine regions of silence in the prediction, set a confidence threshold on the SED decisions. If the SED predictions in a frame are below that threshold, the frame is considered silence.

params.SpatialThreshold = 0.2;
params.SilenceThreshold = 0.1;
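To make the counting rule concrete, the following simplified sketch scores five hypothetical frames, each with at most one active source, as in the non-overlapping subset used here. It is an assumption about how such counts can be formed frame by frame, not the implementation of the computeMetrics supporting function. In particular, a frame with the correct class but a location error above the threshold is counted here as both a false positive and a false negative.

```matlab
spatialThreshold = 0.2;   % params.SpatialThreshold
silenceThreshold = 0.1;   % params.SilenceThreshold

% Hypothetical ground truth: class index per frame (0 = silence) and location.
trueClass = [1;2;0;0;1];
trueLoc = [0.1 0 0; 0 0.5 0; 0 0 0; 0 0 0; 0 0 0.3];

% Hypothetical predictions: SED scores for 2 classes and DOA estimates.
scores = [0.9 0.05; 0.2 0.7; 0.05 0.02; 0.6 0.1; 0.05 0.08];
predLoc = [0.15 0 0; 0 0.9 0; 0 0 0; 0.3 0.3 0; 0 0 0];

% A frame is predicted silent when its highest score is below the threshold.
[maxScore,predClass] = max(scores,[],2);
predClass(maxScore < silenceThreshold) = 0;

% Euclidean location error per frame.
locErr = vecnorm(predLoc - trueLoc,2,2);

% True positive: both active, same class, and location error within threshold.
isTP = predClass > 0 & trueClass > 0 & predClass == trueClass & locErr <= spatialThreshold;
TP = sum(isTP);
FP = sum(predClass > 0 & ~isTP);    % predicted events that are not true positives
FN = sum(trueClass > 0 & ~isTP);    % reference events not detected correctly

precision = TP/(TP + FP);
recall = TP/(TP + FN);
f1Score = 2*precision*recall/(precision + recall);
```

Frame 1 is a true positive, frame 2 has the correct class but a location error of 0.4, frame 4 is a false alarm during silence, and frame 5 is a miss because its highest score falls below the silence threshold.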

Compute the metrics for the validation data set using the computeMetrics supporting function.

results = computeMetrics(sedModel,doaModel,validationSEDmbq,validationDOAmbq,params);
results
results = struct with fields:
    precision: 0.4246
       recall: 0.4275
      f1Score: 0.4261
       avgErr: 0.1861

The computeMetrics supporting function can optionally smooth the decisions over time before evaluating the system. This option requires the Statistics and Machine Learning Toolbox™. Evaluate the system again, this time including the smoothing.

[results,cm] = computeMetrics(sedModel,doaModel,validationSEDmbq,validationDOAmbq,params,ApplySmoothing=true);
results
results = struct with fields:
    precision: 0.5077
       recall: 0.5084
      f1Score: 0.5080
       avgErr: 0.1659
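The smoothing method used by computeMetrics is not shown in this section. One common choice for frame-wise class decisions is a sliding majority (mode) filter. The sketch below uses only base MATLAB functions and is an assumption for illustration, not the supporting function's implementation; the decisions and window length are hypothetical.

```matlab
% Hypothetical frame-wise class decisions (0 would denote silence).
labels = [1 1 1 3 1 1 2 2 2 2];
windowLength = 3;                       % hypothetical odd window length
halfWin = floor(windowLength/2);

smoothed = zeros(size(labels));
for n = 1:numel(labels)
    % Window centered on frame n, truncated at the edges of the signal.
    winStart = max(1,n - halfWin);
    winEnd = min(numel(labels),n + halfWin);
    % Majority vote; mode returns the smallest value when there is a tie.
    smoothed(n) = mode(labels(winStart:winEnd));
end
disp(smoothed)
```

The isolated class-3 decision at frame 4 is replaced by the surrounding class, which is the kind of short spurious switch that smoothing suppresses.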

You can inspect the confusion matrix for SED predictions to get more insights on the prediction errors. The confusion matrix is only calculated over regions where there is an active sound source.

figure(Position=[100 100 800 800]);
confusionchart(cm,keys(params.SoundClasses))

Conclusion

For next steps, you can download the pretrained models from this example and try them out in a second example that shows inference: 3-D Sound Event Localization and Detection Using Trained Recurrent Convolutional Neural Network.

References

[1] Sharath Adavanne, Archontis Politis, Joonas Nikunen, and Tuomas Virtanen, "Sound event localization and detection of overlapping sources using convolutional recurrent neural networks," IEEE J. Sel. Top. Signal Process., vol. 13, no. 1, pp. 34–48, 2019.

[2] Eric Guizzo, Riccardo F. Gramaccioni, Saeid Jamili, Christian Marinoni, Edoardo Massaro, Claudia Medaglia, Giuseppe Nachira, Leonardo Nucciarelli, Ludovica Paglialunga, Marco Pennese, Sveva Pepe, Enrico Rocchi, Aurelio Uncini, and Danilo Comminiello, "L3DAS21 Challenge: Machine Learning for 3D Audio Signal Processing," 2021.

[3] Yin Cao, Qiuqiang Kong, Turab Iqbal, Fengyan An, Wenwu Wang, and Mark D. Plumbley, "Polyphonic sound event detection and localization using a two-stage strategy," arXiv preprint: arXiv:1905.00268v4, 2019.

[4] Annamaria Mesaros, Sharath Adavanne, Archontis Politis, Toni Heittola, and Tuomas Virtanen, "Joint Measurement of Localization and Detection of Sound Events," 2019 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA), 2019. https://doi.org/10.1109/waspaa.2019.8937220.

Supporting Functions

Extract Direction of Arrival (DOA) Targets

function T = extractDOATargets(csvFile,params)
%EXTRACTDOATARGETS Extract direction of arrival (DOA) targets
% T = extractDOATargets(csvFile,params) parses the label table csvFile,
% read from one CSV label file, and returns a 1-by-1 cell array T that
% contains an N-by-3 target matrix, where N corresponds to the number of
% frames and 3 corresponds to the 3 axes describing location in 3-D space.

% Preallocate target matrix. A frame of all zeros corresponds to no sound
% source.
T = zeros(params.Targets.NumFrames,3);

% Quantize the time stamps for sound sources into frames.
startendTime = [csvFile.Start,csvFile.End];
startendFrame = time2frame(startendTime,params.Targets.TotalDuration,params.Targets.NumFrames);

% For each sound source, fill the target matrix sound source location for
% the appropriate number of frames.
for ii = 1:size(startendFrame,1)
    idx = startendFrame(ii,1):startendFrame(ii,2)-1;
    T(idx,:) = repmat([csvFile.X(ii),csvFile.Y(ii),csvFile.Z(ii)],numel(idx),1);
end

% Scale the target so that it is between -1 and 1 (the bounds of the tanh
% activation layer). Wrap the target in a cell array for convenient batch
% processing.
T = {T/params.DOA.ScaleFactor};
end

Extract Sound Event Detection (SED) Targets

function T = extractSEDTargets(csvFile,params)
%EXTRACTSEDTARGETS Extract sound event detection (SED) targets
% T = extractSEDTargets(csvFile,params) parses the label table csvFile,
% read from one CSV label file, and returns a 1-by-1 cell array T that
% contains an N-by-K target matrix, where N corresponds to the number of
% frames and K corresponds to the number of sound classes.

% Preallocate target matrix. A frame of all zeros corresponds to no sound
% source.
T = zeros(params.Targets.NumFrames,params.SoundClasses.Count);

% Quantize the time stamps for sound sources into frames.
startendTime = [csvFile.Start,csvFile.End];
startendFrame = time2frame(startendTime,params.Targets.TotalDuration,params.Targets.NumFrames);

% For each sound source, fill the appropriate column of the target matrix
% with a 1, indicating that the sound class is present in that frame.
for ii = 1:size(startendFrame,1)
    classID = params.SoundClasses(csvFile.Class{ii});
    T(startendFrame(ii,1):startendFrame(ii,2)-1,classID) = 1;
end

% Wrap the target in a cell array for convenient batch processing.
T = {T};
end

Short-Time Fourier Transform (STFT)

function X = extractSTFT(s,params)
%EXTRACTSTFT Extract log-magnitude of centered STFT
% X = extractSTFT({s1,s2},params) concatenates s1 and s2 and then
% extracts the one-sided log-magnitude STFT. The signals are padded before
% the STFT so that the first window is centered on the first sample. The
% output is trimmed to remove the 1st (DC) coefficient and the last
% spectrum. The input params defines the STFT.

% Concatenate the signals along the second (channel) dimension.
audio = cat(2,s{:});

% Extract the centered STFT.
N = numel(params.SED.Window);
overlapLength = N - params.SED.HopLength;
S = centeredSTFT(audio,params.SED.Window,overlapLength,N);

% Trim the 1st coefficient from all spectra and trim the last spectrum.
S = S(2:end,1:end-1,:);

% Convert to log-magnitude. Use an offset to protect against log of zero.
mag = log(abs(S) + eps);

% Cast output to single precision.
X = single(mag);
end

Generalized Cross Correlation with Phase Transform (GCC-PHAT)

function X = extractGCCPHAT(s,params)
%EXTRACTGCCPHAT Extract generalized cross correlation phase transform (GCC-PHAT) features
% X = extractGCCPHAT({s1,s2},params) concatenates s1 and s2 and then
% extracts the GCC-PHAT for all pairs of channels.

% Concatenate the signals corresponding to the two microphones.
audio = cat(2,s{:});

% Count the total number of input channels.
nChan = size(audio,2);

% Calculate the total number of output channels.
numOutputChannels = nchoosek(nChan,2);

% Preallocate a NumFeatures-by-NumFrames-by-NumChannels feature (predictor)
% matrix.
numFrames = size(audio,1)/params.DOA.HopLength;
X = zeros(params.DOA.NumBands,numFrames,numOutputChannels);

% -----------------------------------
% Calculate GCC-PHAT for each pair of channels.
% Precompute STFT for each channel.
N = numel(params.DOA.Window);
overlapLength = N - params.DOA.HopLength;
micAB_stft = centeredSTFT(audio,params.DOA.Window,overlapLength,N);
conjmicAB_stft = conj(micAB_stft(:,:,2:end));
idx = 1;
for ii = 1:nChan - 1
    R = micAB_stft(:,:,ii).*conjmicAB_stft(:,:,ii:end);
    R = exp(1i .* angle(R));
    R = padarray(R, N/2 - 1,"post");
    gcc = fftshift(ifft(R,[],1,"symmetric"),1);
    X(:,:,idx:idx+size(R,3)-1) = gcc(floor(N/2+1 - (params.DOA.NumBands-1)/2):floor(N/2+1 + (params.DOA.NumBands-1)/2),1:end-1,:);

    idx = idx + size(R,3);
end
% -----------------------------------

% Cast output to single precision.
X = single(X);

end

Centered Short-Time Fourier Transform (STFT)

function s = centeredSTFT(audio,win,overlapLength,fftLength)
%CENTEREDSTFT Centered STFT
% s = centeredSTFT(audio,win,overlapLength,fftLength) computes an STFT
% with the first window centered around the first sample. The two ends are
% padded with the reflected audio signal.

% Pad front and back of input signal.
firstR = flip(audio(1:fftLength/2,:),1);
lastR = flip(audio(end - fftLength/2 + 1:end,:),1);
sig = cat(1,firstR,audio,lastR);

% Perform STFT.
s = stft(sig,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");

end

Convert Time Stamp to Frame Number

function fnum = time2frame(t,dur,numFrames)
%TIME2FRAME Convert time stamp to frame number
% fnum = time2frame(t,dur,numFrames) maps the times t, which exist in dur,
% to a frame number if dur is divided into numFrames.

stp = dur/numFrames;

qt = round(t./stp).*stp;

fnum = floor(qt*(numFrames - 1)/dur) + 1;

end

Forward Pass Through CNN and RNN Networks

function [loss,cnnState,rnnState,Y3]  = forwardAll(model,X,T)
%FORWARDALL Forward pass of model through CNN and RNN networks
% [loss,cnnState,rnnState] = forwardAll(model,X,T) passes the predictors X
% through the model and returns the loss and the states of the networks in
% the model. The model is a struct containing a CNN network and an RNN
% network.
%
% [loss,cnnState,rnnState,Y] = forwardAll(model,X,T) also returns the final
% prediction of the model Y.

% Pass predictors through CNN.
[Y1,cnnState] = forward(model.CNN,X);

% Label the dimensions output from the CNN for consumption by the RNN.
Y2 = dlarray(Y1,"TCUB");

% Pass the predictors through the RNN.
[Y3,rnnState] = forward(model.RNN,Y2);

% Calculate the loss.
loss = seldNetLoss(Y3,T);

end

Full Model Prediction

function [loss,Y3]  = predictAll(model,X,T)
%PREDICTALL Model prediction through CNN and RNN networks
% [loss,prediction] = predictAll(model,X,T) passes the predictors X through
% the model and returns the loss and the model prediction. The model is a
% struct containing a CNN network and an RNN network.

% Pass predictors through CNN.
Y1 = predict(model.CNN,X);

% Label the dimensions output from the CNN for consumption by the RNN.
Y2 = dlarray(Y1,"TCUB");

% Pass the CNN output through the RNN.
Y3 = predict(model.RNN,Y2);

% Calculate the loss.
loss = seldNetLoss(Y3,T);

end

Predict Batch

function loss = predictBatch(model,mbq)
%PREDICTBATCH Calculate the average loss over a mini-batch queue
% loss = predictBatch(model,mbq) returns the average per-mini-batch loss
% calculated by passing the entire contents of the mini-batch queue through
% the model.

% Reset mini-batch queue and initialize counters.
reset(mbq)
loss = 0;
n = 0;

while hasdata(mbq)

    % Read the predictors and targets from mini-batch queue.
    [X,T] = next(mbq);

    % Pass the mini-batch through the model and calculate the loss. Divide
    % by the mini-batch size to undo the scaling applied in seldNetLoss.
    lss = predictAll(model,X,T);
    lss = lss/size(T,2);

    % Update the total loss.
    loss = loss + lss;

    % Count the number of mini-batches.
    n = n + 1;

end

% Divide the total loss accumulated by the number of mini-batches.
loss = loss/n;

end
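predictBatch removes the batch-size scaling from each mini-batch loss and then averages the per-mini-batch means, so every mini-batch carries equal weight. The following Python sketch contrasts that behavior with an observation-weighted average. Both function names are hypothetical. The two results differ only when mini-batches have different sizes, for example when the final mini-batch is partial.

```python
def average_of_batch_means(batches):
    """Mirror predictBatch.

    batches is a list of (scaled_loss, batch_size) pairs, where scaled_loss is
    the per-element mean loss multiplied by batch_size (as seldNetLoss returns).
    """
    per_batch_means = [scaled / size for scaled, size in batches]  # lss/size(T,2)
    return sum(per_batch_means) / len(per_batch_means)             # loss/n


def observation_weighted_mean(batches):
    """Alternative that weights every observation equally instead of every mini-batch."""
    total_scaled = sum(scaled for scaled, _ in batches)
    total_obs = sum(size for _, size in batches)
    return total_scaled / total_obs
```

Take a full mini-batch of 4 observations with mean loss 1.0 and a final mini-batch of 1 observation with loss 3.0. The first function returns 2.0 and the second returns 1.4.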

Compute Model Loss, Gradients, and Network States

function [loss,gradients,state] = modelLoss(model,X,T)
%MODELLOSS Compute model loss, gradients, and network states
% [loss,gradients,state] = modelLoss(model,X,T) passes the
% predictors X through the model and returns the loss, the gradients, and
% the states of the networks in the model. The model is a struct containing
% a CNN network and an RNN network.

% Pass the predictors through the model.
[loss,cnnState,rnnState] = forwardAll(model,X,T);

% Isolate the learnables.
allGrad.CNN = model.CNN.Learnables;
allGrad.RNN = model.RNN.Learnables;

state.CNN = cnnState;
state.RNN = rnnState;

% Calculate the gradients.
gradients = dlgradient(loss,allGrad);

end

Loss Function of SELDnet

function loss = seldNetLoss(Y,T)
%SELDNETLOSS Compute the SELDnet loss function for DOA or SED models
% loss = seldNetLoss(Y,T) returns the SELDnet loss given predictions Y and
% targets T. The loss function depends on the network (DOA or SED), which
% is inferred from the size of the channel dimension of the target: three
% channels (x, y, z) indicate the DOA network. For the DOA network, the loss
% function is mean-squared error. For the SED network, the loss function is
% multilabel cross-entropy.

% Determine whether the targets correspond to the DOA network or SED
% network.
isDOAModel = size(T,find(dims(T)=='C'))==3;

if isDOAModel
    % Calculate MSE loss.
    doaLoss = mse(Y,T);
    doaLossFactor = 2 / (size(Y,1) * size(Y,3));
    loss = doaLoss * doaLossFactor; % To align with the original implementation
else
    % Calculate cross-entropy loss.
    loss = crossentropy(Y,T,ClassificationMode="multilabel",NormalizationFactor="all-elements");
end

% Scale the loss by the number of observations in the mini-batch.
loss = loss * size(T,2);

end
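This sketch assumes that the Deep Learning Toolbox mse function returns the half mean squared error normalized by the number of observations. Under that assumption, the factor 2/(size(Y,1)*size(Y,3)) turns it into the mean squared error per element (channel, observation, and time step). The following Python sketch computes both SELDnet losses directly in that per-element form, including the final multiplication by the mini-batch size. The helper names doa_loss and sed_loss and the flattened list inputs are hypothetical.

```python
import math


def doa_loss(y, t, batch_size):
    """Per-element mean squared error, scaled by the mini-batch size."""
    sq = [(a - b) ** 2 for a, b in zip(y, t)]
    return sum(sq) / len(sq) * batch_size


def sed_loss(y, t, batch_size, eps=1e-12):
    """Multilabel binary cross-entropy averaged over all elements, scaled by batch size.

    y holds per-class probabilities, t holds 0/1 targets; eps guards log(0).
    """
    terms = [-(b * math.log(max(a, eps)) + (1 - b) * math.log(max(1 - a, eps)))
             for a, b in zip(y, t)]
    return sum(terms) / len(terms) * batch_size
```

Averaging over all elements keeps the loss magnitude independent of the number of classes and frames. The final multiplication by the batch size is what predictBatch later divides out.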

Compute Performance Metrics

function [r,cm] = computeMetrics(sedModel,doaModel,sedMBQ,doaMBQ,params,nvargs)
%COMPUTEMETRICS Compute performance metrics
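% [r,cm] = computeMetrics(sedModel,doaModel,sedMBQ,doaMBQ,params) returns
% a struct r of performance metrics calculated over the SED and DOA
% validation mini-batch queues, and the confusion matrix cm of the SED
% predictions with the silence class removed.
%
% [r,cm] = computeMetrics(...,ApplySmoothing=true) smooths the DOA and SED
% predictions before computing the metrics.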
arguments
    sedModel
    doaModel
    sedMBQ
    doaMBQ
    params
    nvargs.ApplySmoothing = false;
end

% Initialize counters.
TP = 0;
FP = 0;
FN = 0;
it = 0;
ct = 0;
err = 0;

sedYAll = [];
sedTAll = [];

% Loop over all the data.
reset(sedMBQ)
reset(doaMBQ)
while hasdata(sedMBQ)

    % Get the predictors, targets, and predictions for the SED model.
    [sedXb,sedTb] = next(sedMBQ);
    [~,sedYb]  = predictAll(sedModel,sedXb,sedTb);
    sedTb = extractdata(gather(sedTb));
    sedYb = extractdata(gather(sedYb));

    % Get the predictors, targets, and predictions for the DOA model.
    [doaXb,doaTb] = next(doaMBQ);
    [~,doaYb]  = predictAll(doaModel,doaXb,doaTb);
    doaTb = extractdata(gather(doaTb));
    doaYb = extractdata(gather(doaYb));
    doaYb = doaYb*params.DOA.ScaleFactor;
    doaTb = doaTb*params.DOA.ScaleFactor;

    % Loop over the observations in the mini-batch.
    for batch = 1:size(sedYb,2)

        % Isolate the predictions and targets for the current observation.
        sedY = squeeze(sedYb(:,batch,:));
        sedT = squeeze(sedTb(:,batch,:));
        doaY = squeeze(doaYb(:,batch,:));
        doaT = squeeze(doaTb(:,batch,:));

        % If the SED predictions of a frame are all made with low
        % confidence (beneath a threshold), assume that there is no sound
        % source present.
        isActive = ~(sum(double(sedY<params.SilenceThreshold),1)==size(sedY,1));

        % Convert the SED predictions and targets from per-class scores to
        % class indices, using 0 for silent frames.
        [~,sedY] = max(sedY,[],1);
        sedY = sedY.*isActive;

        [isActive,sedT] = max(sedT,[],1);
        sedT = sedT.*isActive;

        % Smooth outputs.
        if nvargs.ApplySmoothing
            [doaY,sedY] = smoothOutputs(doaY,sedY,params);
        end

        % Perform location-sensitive detection.
        [tp,fp,fn,e,c] = locationSensitiveDetection(sedY,sedT,doaY,doaT,params);
        
        % Accumulate performance metrics.
        TP = TP + tp;
        FP = FP + fp;
        FN = FN + fn;
        err = err + e;
        ct = ct + c;

        sedYAll = [sedYAll sedY.*isActive]; %#ok<AGROW> 
        sedTAll = [sedTAll sedT.*isActive]; %#ok<AGROW> 
    end
    it = it + 1;
end

% Calculate performance metrics.
r.precision = TP/(TP + FP + eps);
r.recall = TP/(TP + FN + eps);
r.f1Score = 2*(r.precision*r.recall)/(r.precision + r.recall + eps);
r.avgErr = err/ct;

% Calculate confusion matrix.
confmat = confusionmat(sedTAll,single(sedYAll),Order=0:14);
cm = confmat(2:end,2:end); % Remove the silence from the confusion matrix.
end
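Each frame is decoded in two steps. First, a frame is treated as silent when every class score falls below params.SilenceThreshold. Otherwise, it takes the index of the highest-scoring class. The summary metrics then follow the formulas at the end of computeMetrics. The following Python sketch shows both steps with hypothetical helper names. It uses the double-precision machine epsilon in place of MATLAB eps.

```python
import sys

EPS = sys.float_info.epsilon  # counterpart of MATLAB eps for doubles


def decode_frame(probs, silence_threshold):
    """Return the 1-based index of the most likely class, or 0 for silence."""
    if all(p < silence_threshold for p in probs):
        return 0
    # max() returns the first maximal index, like MATLAB max.
    best = max(range(len(probs)), key=lambda i: probs[i])
    return best + 1


def summary_metrics(tp, fp, fn, total_err, count):
    """Precision, recall, F1 score, and average DOA error."""
    precision = tp / (tp + fp + EPS)
    recall = tp / (tp + fn + EPS)
    f1 = 2 * precision * recall / (precision + recall + EPS)
    avg_err = total_err / count
    return precision, recall, f1, avg_err
```

Adding eps to each denominator keeps the metrics defined when no positives are predicted or present. The cost is a bias too small to matter at these count sizes.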

Location-Sensitive Detection

function [TP,FP,FN,totErr,ct] = locationSensitiveDetection(sedY,sedT,doaY,doaT,params)
%LOCATIONSENSITIVEDETECTION Location-sensitive detection
% [TP,FP,FN,totErr,ct] =
% locationSensitiveDetection(sedY,sedT,doaY,doaT,params) returns the
% numbers of true positive, false positive, and false negative frames, the
% total DOA error over frames with an active reference, and the number of
% such frames. The definitions of each metric are provided in [4].

% Calculate distance.
dist = vecnorm(doaY-doaT);

% Determine if sounds active for reference and predictions.
isReferenceActive = sedT~=0;
isPredictedActive = sedY~=0;

% Calculate the total DOA error for reference-active sections.
totErr = sum(dist.*isReferenceActive);

% Count total number of active targets.
ct = sum(isReferenceActive);

% Determine if the DOA is within threshold per frame.
isDOAnear = dist < params.SpatialThreshold;

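% True positive: the reference and prediction are both active, the classes
% match, and the DOA error is within the spatial threshold.
TP = sum(isDOAnear & isReferenceActive & isPredictedActive & (sedT==sedY));

% False positive: the prediction is active while the reference is silent,
% or an active prediction has the wrong class or DOA.
FP1 = sum(~isReferenceActive & isPredictedActive);
FP2 = sum(isReferenceActive & isPredictedActive & (sedT~=sedY | ~isDOAnear));
FP = FP1 + FP2;

% False negative: the reference is active but nothing is predicted, or the
% active prediction has the wrong class or DOA. Restricting FN2 to active
% predictions avoids counting missed frames twice.
FN1 = sum(isReferenceActive & ~isPredictedActive);
FN2 = sum(isReferenceActive & isPredictedActive & (sedT~=sedY | ~isDOAnear));
FN = FN1 + FN2;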

end
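A frame-by-frame Python sketch of the same counting makes the bookkeeping explicit. The name lsd_counts is hypothetical. The sketch uses this convention: an active prediction with the wrong class, or with a DOA error at or above the spatial threshold, counts as both a false positive and a false negative. A reference frame is never counted as a false negative more than once.

```python
import math


def lsd_counts(sed_y, sed_t, doa_y, doa_t, spatial_threshold):
    """Frame-wise TP/FP/FN, total DOA error, and number of active reference frames.

    sed_y, sed_t: class index per frame (0 = silence).
    doa_y, doa_t: (x, y, z) position per frame.
    """
    tp = fp = fn = 0
    tot_err = 0.0
    ct = 0
    for y_cls, t_cls, y_pos, t_pos in zip(sed_y, sed_t, doa_y, doa_t):
        dist = math.dist(y_pos, t_pos)
        ref_active = t_cls != 0
        pred_active = y_cls != 0
        match = (ref_active and pred_active and y_cls == t_cls
                 and dist < spatial_threshold)
        if ref_active:
            tot_err += dist   # DOA error is accumulated over active references
            ct += 1
        if match:
            tp += 1
        else:
            if pred_active:
                fp += 1       # spurious or mismatched prediction
            if ref_active:
                fn += 1       # missed or mismatched reference
    return tp, fp, fn, tot_err, ct
```

Because the DOA error is accumulated over every frame with an active reference, avgErr also penalizes localization in frames whose class was missed.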

Smooth Outputs

function [doaYSmooth,sedYSmooth] = smoothOutputs(doaY,sedY,params)
%SMOOTHOUTPUTS Smooth DOA and SED predictions over time
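% [doaYSmooth,sedYSmooth] = smoothOutputs(doaY,sedY,params) clusters the
% DOA predictions, splits the frames into runs of consecutive frames that
% belong to the same cluster, and replaces the predictions in each run with
% a trimmed mean (DOA) or the most frequent class (SED). A 5-frame moving
% median is then applied to the SED predictions.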

% Preallocate smoothed outputs.
doaYSmooth = doaY;
sedYSmooth = sedY;

% Cluster the DOA predictions.
clusters = clusterdata(doaY',Criterion="distance",Cutoff=params.SpatialThreshold);
stt = 1;
enn = 1;

while enn <= params.Targets.NumFrames

    if clusters(stt) == clusters(enn)
        enn = enn + 1;
    else
        doaYSmooth(:,stt:enn-1) = smoothDOA(doaY(:,stt:enn-1));
        sedYSmooth(:,stt:enn-1) = smoothSED(sedY(:,stt:enn-1));
        stt = enn;
    end

end

doaYSmooth(:,stt:enn-1) = smoothDOA(doaY(:,stt:enn-1));
sedYSmooth(:,stt:enn-1) = smoothSED(sedY(:,stt:enn-1));

sedYSmooth = round(movmedian(sedYSmooth,5));

end
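The while loop in smoothOutputs compares each frame's cluster label with the label at the start of the current run. A cluster that reappears later therefore starts a new segment instead of being merged with its earlier frames. The following Python sketch shows only that segmentation step, using a hypothetical helper cluster_runs. It takes the cluster labels as given and does not reproduce the hierarchical clustering performed by clusterdata.

```python
def cluster_runs(labels):
    """Split a frame sequence into maximal runs of consecutive equal labels.

    Returns half-open (start, stop) index pairs covering every frame once.
    """
    runs = []
    start = 0
    for i in range(1, len(labels) + 1):
        # Close the current run at the end of the data or when the label changes.
        if i == len(labels) or labels[i] != labels[start]:
            runs.append((start, i))
            start = i
    return runs
```

For the labels [1, 1, 2, 2, 2, 1], the runs are frames 0 to 1, 2 to 4, and 5. Each run is then smoothed independently.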

Smooth DOA Prediction

function smoothed = smoothDOA(chunk)
%SMOOTHDOA Smooth DOA prediction
% smoothed = smoothDOA(chunk) smooths DOA predictions by replacing the
% values of each axis with the mean of that axis in the chunk. The mean is
% calculated after discarding the lower and upper quarters of data.

% Determine the length of the chunk and the indices that keep approximately
% the middle half of the sorted data.
chlen = size(chunk,2);
st = max(round(chlen*1/4),1);
en = max(round(chlen*3/4),1);

% Sort the values of each spatial axis (row) over time.
dim = sort(chunk,2);

% Take the mean of the inner half.
smoothed = repmat(mean(dim(:,st:en),2),1,chlen);

end
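smoothDOA computes a trimmed mean per spatial axis. The following Python sketch mirrors its 1-based, inclusive indexing. The names smooth_doa and matlab_round are hypothetical, and matlab_round reproduces MATLAB's round-half-away-from-zero for non-negative inputs. Because round(chlen/4) and round(3*chlen/4) are inclusive bounds, short chunks keep slightly more than half of the data. With 4 frames, for example, only the largest value is discarded.

```python
import math


def matlab_round(x):
    """Round half away from zero for non-negative x, like MATLAB round."""
    return math.floor(x + 0.5)


def smooth_doa(chunk):
    """chunk: one list per spatial axis, each holding per-frame values.

    Replaces every value of an axis with the mean of its sorted middle portion.
    """
    chlen = len(chunk[0])
    st = max(matlab_round(chlen * 1 / 4), 1)   # 1-based start index
    en = max(matlab_round(chlen * 3 / 4), 1)   # 1-based inclusive end index
    smoothed = []
    for axis in chunk:
        inner = sorted(axis)[st - 1:en]        # same elements as dim(:,st:en)
        mean = sum(inner) / len(inner)
        smoothed.append([mean] * chlen)
    return smoothed
```

Discarding the extremes makes the segment's DOA robust to a few outlier frames, which a plain mean would not be.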

Smooth SED Prediction

function smoothed = smoothSED(chunk)
%SMOOTHSED Smooth SED prediction
% smoothed = smoothSED(chunk) replaces every frame in the chunk with the
% most frequent class in the chunk (the mode).

smoothed = repmat(mode(chunk),1,size(chunk,2));

end
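When several classes occur equally often, MATLAB mode returns the smallest of them. Python's statistics.mode instead returns the first value encountered. The following sketch (hypothetical name smooth_sed) therefore breaks ties explicitly to match smoothSED.

```python
from collections import Counter


def smooth_sed(chunk):
    """Replace every frame label with the most frequent label in the chunk.

    Ties go to the smallest label, matching MATLAB mode.
    """
    counts = Counter(chunk)
    top = max(counts.values())
    winner = min(label for label, c in counts.items() if c == top)
    return [winner] * len(chunk)
```

Because silence is encoded as 0, a tie between silence and a sound class resolves to silence.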