Joint Time-Frequency Scattering

A joint time-frequency scattering (JTFS) network enables you to extract features from a signal that are invariant to shifts or deformations in both time and frequency. Time and frequency invariance makes JTFS features robust inputs in AI classification workflows. For more information on using JTFS in such workflows, see Acoustic Scene Classification with Wavelet Scattering and Musical Instrument Classification with Joint Time-Frequency Scattering.

Andén, Lostanlen, and Mallat developed the JTFS transform as an extension of wavelet time scattering [1]. The wavelet time scattering transform filters data along the time dimension, and then applies pointwise modulus nonlinearities. The JTFS transform additionally filters the data along the frequency dimension, followed by pointwise modulus nonlinearities.

In this example, you learn how to extend wavelet time scattering into the JTFS framework. The time wavelets are related to the separable time-frequency wavelets found in JTFS. You will visualize those wavelets and learn how they are sensitive to the time-frequency geometry of a signal. You also demonstrate the equivalence of the time lowpass filter in JTFS with the lowpass filter in wavelet time scattering.

Wavelet Time Scattering

A wavelet (time) scattering network enables you to derive low-variance features from time series data to use in AI applications. The features are insensitive to translations in time on a scale that you can specify.

In a wavelet scattering network, the input data is first convolved with wavelet filters. Then, a pointwise modulus nonlinearity is applied to the filter bank outputs. The result is the scalogram of the input data. The scalogram is then smoothed (averaged) with a lowpass (scaling) filter. The process repeats for the number of filter banks specified in the network. For more information, see Wavelet Scattering. This is a tree view of a network with two filter banks.

In the tree view:

  • x is the input time series.

  • $\overset{T}{\ast}$ denotes convolution in time.

  • $\psi_\lambda^{(1)}$ and $\psi_\mu^{(2)}$ are the (time) wavelets in the first- and second-order filter banks, respectively. The subscripts λ and μ correspond to the wavelet center frequencies.

  • $\phi_T$ is the scaling function.

By default, waveletScattering creates a scattering network with two filter banks.

The first-order scalogram coefficients

$U_1x(t,\lambda)=\left|x\overset{T}{\ast}\psi_\lambda^{(1)}(t)\right|$

are the absolute magnitude of the continuous wavelet transform (CWT) of the input data. The first-order scattering coefficients

$S_1x(t,\lambda)=U_1x(t,\lambda)\overset{T}{\ast}\phi_T(t)$

are the smoothed scalogram coefficients. The first-order scalogram and scattering coefficients are a subset of the coefficients returned by the waveletScattering object function scatteringTransform.
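The two definitions above can be sketched for a single first-order path in base MATLAB. This is an illustration only, not the filter design that waveletScattering uses: the Gaussian band-pass and lowpass filters, their bandwidths, and the test tone are all assumptions chosen to keep the arithmetic transparent.

```matlab
% One first-order scattering path with stand-in Gaussian filters (assumed, not the toolbox filters).
N = 1024;                                   % signal length in samples (assumed)
n = (0:N-1)';                               % sample index
lambda = 0.125;                             % wavelet center frequency in cycles/sample (assumed)
x = cos(2*pi*lambda*n);                     % test tone with an integer number of cycles

xi = [0:N/2-1, -N/2:-1]'/N;                 % FFT frequency grid in cycles/sample
psihat = exp(-(xi-lambda).^2/(2*0.02^2));   % analytic band-pass filter centered at lambda
phihat = exp(-xi.^2/(2*0.005^2));           % lowpass (scaling) filter stand-in

U1 = abs(ifft(fft(x).*psihat));             % scalogram row: modulus of x convolved with the wavelet
S1 = real(ifft(fft(U1).*phihat));           % scattering row: scalogram smoothed by the lowpass filter
```

Because the band-pass filter keeps only the positive-frequency half of the tone, U1 is a constant envelope of about 0.5, and the lowpass filter leaves that constant unchanged in S1.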

Joint Time-Frequency Scattering

The JTFS transform is inspired by biology. Research into the primary auditory cortex has demonstrated the existence of spectro-temporal receptive fields (STRFs) at the cortical level. The STRF of a neuron is a function of time and frequency, which describes its post-stimulus time histogram in response to various stimuli. STRFs at the cortical level exhibit ripple-like responses around a given point (t,λ) in the time-frequency domain. This ripple-like behavior can be described in terms of a temporal modulation rate, μ, and a frequency modulation rate, ℓ. The temporal modulation rate is in hertz and the frequency modulation rate is in cycles per octave, also called quefrency. Accordingly, the full cortical model requires four parameters: (t,λ,μ,ℓ).

The wavelets $\psi_\mu(t)=\frac{1}{\mu}\psi\left(\frac{t}{\mu}\right)$ are used to obtain the second-order time scattering coefficients:

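$S_2x(t,\lambda,\mu)=\left|\left|x\overset{T}{\ast}\psi_\lambda^{(1)}\right|\overset{T}{\ast}\psi_\mu^{(2)}\right|\overset{T}{\ast}\phi_T(t)=U_2x(t,\lambda,\mu)\overset{T}{\ast}\phi_T(t).$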

These scattering coefficients are the smoothed second-order scalogram coefficients. Both sets of coefficients are outputs of the waveletScattering object function scatteringTransform.
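As a rough illustration of what the second-order coefficients capture, the sketch below applies the cascade to an amplitude-modulated tone. The Gaussian filters, the helper names gaussBand and filt, and every numeric setting are assumptions made for this sketch; they are not the filters or functions that waveletScattering provides.

```matlab
% Second-order path on an amplitude-modulated tone (stand-in Gaussian filters, assumed settings).
N = 1024;  n = (0:N-1)';
fc = 0.25;  fm = 1/64;                          % carrier and modulation rates in cycles/sample (assumed)
x = (1 + 0.5*cos(2*pi*fm*n)).*cos(2*pi*fc*n);   % modulated test signal

xi = [0:N/2-1, -N/2:-1]'/N;                     % FFT frequency grid
gaussBand = @(center,sigma) exp(-(xi-center).^2/(2*sigma^2));   % hypothetical filter helper
filt = @(u,h) ifft(fft(u).*h);                  % circular convolution through the FFT

U1 = abs(filt(x,gaussBand(fc,0.03)));           % first-order scalogram row: envelope of the tone
phihat = gaussBand(0,0.002);                    % lowpass (scaling) filter stand-in
S2 = @(mu) real(filt(abs(filt(U1,gaussBand(mu,0.003))),phihat));   % second-order path at rate mu

S2match = mean(S2(fm));                         % second-order wavelet tuned to the modulation rate
S2miss  = mean(S2(4*fm));                       % second-order wavelet tuned elsewhere
```

In this setup S2match is much larger than S2miss, which reflects how μ indexes the rate at which the first-order envelope fluctuates.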

Joint time-frequency scattering takes wavelet time scattering one step further and creates a feature extractor that mimics the full cortical model. The second-order scattering coefficients depend on μ, the center frequencies of the second-order time wavelets, $\psi_\mu^{(2)}(t)$. To account for the time-frequency ripple-like responses, JTFS uses a two-dimensional separable wavelet:

$\Psi_{\mu,\ell,s}(t,\lambda)=\psi_\mu^{(2)}(t)\,\psi_{\ell,s}(s\lambda).$

The $\psi_{\ell,s}(s\lambda)$ wavelets are frequential wavelets. Recall that the λ are the center frequencies of the first-order time wavelets, and that the superscript (2) denotes a wavelet used in the second-order time scattering. The variable s is the spin, which takes the values ±1. Therefore, with respect to frequency modulations, the separable wavelet permits both positive and negative frequencies, whereas the time wavelets are typically analytic.
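The separable structure can be sketched as an outer product of two one-dimensional wavelets. The Gabor-style wavelets, grid sizes, rates, and envelope widths below are assumptions for illustration and do not reproduce the filters that timeFrequencyScattering designs.

```matlab
% Separable 2-D wavelet as an outer product (illustrative Gabor-style wavelets, assumed settings).
Nt = 256;  Nf = 64;
t = -Nt/2:Nt/2-1;                    % time samples (row vector)
lam = (-Nf/2:Nf/2-1)';               % log-frequency index (column vector)
mu = 0.05;  ell = 0.1;               % temporal and frequency modulation rates (assumed)

psiT = exp(-t.^2/(2*30^2)).*exp(1j*2*pi*mu*t);         % time wavelet with rate mu
psiF = @(u) exp(-u.^2/(2*8^2)).*exp(1j*2*pi*ell*u);    % frequential wavelet prototype

PsiUp   = psiF(+1*lam)*psiT;         % spin s = +1: Nf-by-Nt separable wavelet
PsiDown = psiF(-1*lam)*psiT;         % spin s = -1: frequential wavelet evaluated at -lam
```

Evaluating the prototype at -lam conjugates it, so the two spins produce ripples with opposite orientations in the time-frequency plane. Plotting real(PsiUp) and real(PsiDown) with imagesc(t,lam,...) makes the difference visible.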

Similar to the case with time scattering, in JTFS, you apply a joint time-frequency lowpass filter $\phi_{T,F}$. The JTFS coefficients are defined as:

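$S(t,\lambda,\mu,\ell,s)=\left|\left|x\overset{T}{\ast}\psi_\lambda^{(1)}\right|\overset{T,F}{\ast}\Psi_{\mu,\ell,s}\right|\overset{T,F}{\ast}\phi_{T,F}.$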
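The joint filtering step can be sketched on a synthetic scalogram that contains a single ripple. Everything here is an assumption made for illustration: the ripple stands in for a first-order scalogram, the Gaussian filters stand in for the toolbox wavelets, and jointResp is a hypothetical helper. The final lowpass smoothing is omitted for brevity.

```matlab
% Joint time-frequency filtering of a synthetic ripple (stand-in filters, assumed settings).
Nf = 32;  Nt = 256;
lam = (0:Nf-1)';  t = 0:Nt-1;                 % log-frequency index (rows) and time (columns)
ell0 = 4/Nf;  mu0 = 16/Nt;                    % ripple rates with integer cycles on the grid
U = 1 + cos(2*pi*(mu0*t + ell0*lam));         % synthetic scalogram with a single ripple

fl = [0:Nf/2-1, -Nf/2:-1]'/Nf;                % quefrency grid (column vector)
ft = [0:Nt/2-1, -Nt/2:-1]/Nt;                 % temporal rate grid (row vector)
psiThat = exp(-(ft-mu0).^2/(2*0.01^2));       % analytic time wavelet centered at mu0
psiFhat = @(s) exp(-(fl-s*ell0).^2/(2*0.03^2));   % frequential wavelet centered at s*ell0

% Hypothetical helper: modulus of the 2-D convolution with the separable wavelet of spin s.
jointResp = @(s) abs(ifft2(fft2(U).*(psiFhat(s)*psiThat)));

Eplus  = mean(jointResp(+1),"all");           % response for spin +1
Eminus = mean(jointResp(-1),"all");           % response for spin -1
```

Only one of the two spins responds to the ripple, which is the property that lets the spinned coefficients separate up-going from down-going structures in the scalogram.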

Visualize JTFS Separable 2-D Wavelet

Create a JTFS network.

jtfn = timeFrequencyScattering;

Use the filterbank object function to obtain the second-order time wavelet filter bank and its metadata.

[~,psi2f,~,timemeta] = filterbank(jtfn);
centerFrequency = timemeta{2}.xi; % center frequencies of the second-order time wavelets
minCenterFrequency = min(centerFrequency);
whichCF = minCenterFrequency; %#ok<*NASGU>

Use the same function to obtain the spin-up and spin-down wavelets and their metadata. The output variable frequencymeta contains the metadata for both the spin-up and spin-down wavelets.

[psifup,psifdown,~,frequencymeta] = filterbank(jtfn, ...
    FilterBank="frequency");
centerQuefrency = frequencymeta.xi;
minCenterQuefrency = min(centerQuefrency);
whichCQ = minCenterQuefrency;

A JTFS separable 2-D wavelet $\Psi_{\mu,\ell,s}(t,\lambda)$ is defined as $\Psi_{\mu,\ell,s}(t,\lambda)=\psi_\mu^{(2)}(t)\,\psi_{\ell,s}(s\lambda)$, where $\psi_\mu^{(2)}(t)$ is a second-order time wavelet with center frequency (time modulation rate) μ, and $\psi_{\ell,s}(s\lambda)$ is a frequential wavelet with center quefrency (frequency modulation rate) ℓ and spin s. The frequential wavelet $\psi_{\ell,1}$ is a spin-up wavelet, and $\psi_{\ell,-1}$ is a spin-down wavelet.

Select the center frequency of a time wavelet and the center quefrency of either a spin-up or spin-down wavelet. Use the helper function helperPlotSeparableWavelet to plot the real part of the separable 2-D wavelet associated with the time and frequential wavelets. To choose a spin-down wavelet, select a negative whichCQ value. You can use the same helper function to plot the imaginary part and magnitude of $\Psi_{\mu,\ell,s}$. The source code for this helper function is in the same folder as this example file.

whichCF = centerFrequency(9); % time wavelet center frequency
whichCQ = centerQuefrency(10); % frequential wavelet center quefrency
plotType = "real";
helperPlotSeparableWavelet(psi2f,timemeta,psifup,psifdown,frequencymeta,whichCF,whichCQ,plotType)

Figure: Separable Wavelet: Real Part (center frequency μ: 0.0015625 cycles/sample, center quefrency ℓ: -0.025 cycles/octave, spin s: -1).

JTFS and Sensitivity to Time-Frequency Geometry

Create a signal that consists of a quadratic chirp and two sinusoids with disjoint time support. Each sinusoid has a different frequency. Sample the signal at 2000 Hz for four seconds.

tspan = 4;
Fs = 2e3;
t = 0:1/Fs:tspan-1/Fs;
y = chirp(tspan/2-t,30,max(tspan/2-t),200,"quadratic",[],"concave");
sig1 = cos(5*2*pi*t);
sig2 = cos(450*2*pi*t);
sig1 = sig1.*(t<tspan/2);
sig2 = sig2.*(t>=tspan/2);
sig = sig1+sig2+y;
plot(t,sig)
title("Signal")
xlabel("Time (s)")
ylabel("Amplitude")

Figure: Signal. Amplitude versus Time (s).

Use the cwt function to plot the scalogram of the signal. The signal has a nontrivial time-frequency geometry.

cwt(sig,Fs)

Figure: Magnitude Scalogram. Frequency (Hz) versus Time (secs).

One characteristic of JTFS coefficients is that different sets of coefficients are sensitive to different parts of the time-frequency geometry. To see this, use the timeFrequencyScattering function to create a JTFS network appropriate for the signal. Use the scatteringTransform function to obtain the JTFS of the signal. The five sets of JTFS coefficients are in the dictionary outCFS, and the metadata describing each set is in the cell array outMETA.

sigLength = length(sig);
jtfn = timeFrequencyScattering(SignalLength=sigLength, ...
    TimeInvarianceScale=16, ...
    TimeQualityFactors=[16 1], ...
    TimeMaxPaddingFactor=2, ...
    FrequencyInvarianceScale=1, ...
    NumFrequencyOctaves=2, ...
    FrequencyQualityFactor=2, ...
    FrequencyMaxPaddingFactor=2);

[outCFS,outMETA] = scatteringTransform(jtfn,sig);

Use the scattergram function to visualize the spin-up and spin-down coefficients. A JTFS separable 2-D wavelet is the product of a second-order time wavelet and a frequential wavelet of spin 1 or -1. Each row in the plot corresponds to a frequential wavelet of a given center quefrency. Spin-up and spin-down wavelets have positive and negative center quefrencies, respectively. Each column corresponds to a second-order time wavelet of a given center frequency.

The spin-up coefficients preferentially localize the up-chirp portion of the quadratic chirp, and the spin-down coefficients preferentially localize the down-chirp portion. Separable wavelets with lower temporal modulation rates and higher frequency modulation rates tend to localize the sinusoids. Separable wavelets with higher temporal modulation rates and lower frequency modulation rates tend to localize the discontinuity between the sinusoids.

scattergram(jtfn,outCFS,outMETA,PlotType="Spinned")

Figure: Scattergram of the spinned coefficients (100 axes). Column labels: 0.1, 0.05, 0.025, 0.013, 0.0063, 0.0031, 0.0016, 0.00078, 0.00039, 0.0002. Row labels: +0.41, +0.29, +0.21, +0.14, +0.069 (spin-up) and -0.41, -0.29, -0.21, -0.14, -0.069 (spin-down).

You can visualize the separable 2-D wavelet associated with a set of spinned coefficients. First, use the filterbank function to extract the second-order time wavelets and spinned wavelets from the JTFS network. Also obtain their metadata.

[~,psi2f,~,timemeta] = filterbank(jtfn);
[psifup,psifdown,~,frequencymeta] = filterbank(jtfn,FilterBank="frequency");

The variables outMETA{3} and outMETA{4} are tables describing the spin-up and spin-down coefficients, respectively. Each table row corresponds to a subplot in the scattergram, and each subplot corresponds to a specific separable 2-D wavelet. Inspect the spin-up metadata. The path table variable indicates the coefficient path. The first column is the index of the frequential wavelet, and the second column is the index of the second-order time wavelet.

outMETA{3}
ans=50×5 table
      type      log2dsfactor     path     spin    log2stride
    ________    ____________    ______    ____    __________

    "SpinUp"       0    1       1    3     1        0    4  
    "SpinUp"       0    1       2    3     1        0    4  
    "SpinUp"       0    1       3    3     1        0    4  
    "SpinUp"       0    1       4    3     1        0    4  
    "SpinUp"       0    1       5    3     1        0    4  
    "SpinUp"       0    2       1    4     1        0    4  
    "SpinUp"       0    2       2    4     1        0    4  
    "SpinUp"       0    2       3    4     1        0    4  
    "SpinUp"       0    2       4    4     1        0    4  
    "SpinUp"       0    2       5    4     1        0    4  
    "SpinUp"       0    3       1    5     1        0    4  
    "SpinUp"       0    3       2    5     1        0    4  
    "SpinUp"       0    3       3    5     1        0    4  
    "SpinUp"       0    3       4    5     1        0    4  
    "SpinUp"       0    3       5    5     1        0    4  
    "SpinUp"       0    4       1    6     1        0    4  
      ⋮

You can use the values in the path variable to index into timemeta and frequencymeta and retrieve the metadata that describes the time and frequential wavelets. Choose a path value from one of the rows in outMETA{3}. Use that value to display the metadata describing the frequential and second-order time wavelets associated with the separable wavelet. In frequencymeta, the xi variable contains the center quefrency of the frequential wavelet. In timemeta{2}, which describes the second-order time filter bank, xi contains the center frequency of the time wavelet.

pathValue = [3 10];
frequencymeta(pathValue(1),:)
ans=1×7 table
      xi        sigma      isCQT    log2dsfactor    spin    peakidx     bwidx  
    _______    ________    _____    ____________    ____    _______    ________

    0.20711    0.042681      1           0           1        32       14    49

timemeta{2}(pathValue(2),:)
ans=1×6 table
        xi          sigma       isCQT    log2dsfactor    peakidx     bwidx 
    __________    __________    _____    ____________    _______    _______

    0.00078125    0.00031279      1           8            13       2    25

Use the helper function helperPlotJTFSWaveletAndCFS to plot the separable wavelet and the coefficients associated with it. The source code for this helper function is in the same folder as this example file.

figure
plotType = "real";
helperPlotJTFSWaveletAndCFS(pathValue,psifup,frequencymeta,psi2f,timemeta,outCFS{"SpinUp"},outMETA{3},plotType)

Figure: two panels, titled Separable Wavelet: Real Part and JTFS Coefficients.

Equivalence of JTFS and Wavelet Time Scattering Lowpass Filters

Joint time-frequency scattering is an extension of wavelet time scattering. This example shows the equivalence of the time lowpass filter in a JTFS network with the lowpass filter in a wavelet time scattering network.

Create a wavelet time scattering network. Specify a signal length of 2^16 samples and single precision. If the signal intended for the network is sampled at 10 kHz, then one second of data corresponds to 10,000 samples. Set the invariance scale to 10,000 samples.

slen = 2^16;
invscale = 1e4;
tsn = waveletScattering(SignalLength=slen, ...
    Precision="single", ...
    InvarianceScale=invscale);

Use the time scattering filterbank object function to obtain the lowpass (scaling) filter.

tsnfb = filterbank(tsn);
phifTS = tsnfb{1}.phift;

The time invariance scale determines the size of the time support of the lowpass filter. The scalogram coefficients are smoothed with a lowpass filter to obtain the scattering coefficients. So the time invariance scale determines how much smoothing in time is done.

To show the equivalence of the two lowpass filters, you need to determine the relationship between the time invariance scale in waveletScattering and the time invariance scale in timeFrequencyScattering. A literature search shows there is no single definition for the time invariance scale. In this example, for a given time invariance scale in wavelet time scattering, you learn how to determine the corresponding time invariance scale in JTFS. In other words, you learn how to perform equivalent smoothing, in time, in wavelet time scattering and JTFS.

To begin, recall that for the normal distribution, 99% of the probability is localized in the interval $[-2.5758\sigma_t,\,2.5758\sigma_t]$, where $\sigma_t$ is the time standard deviation (2.5758 is the 99.5th percentile of the standard normal distribution). In wavelet time scattering, which uses the Gaussian lowpass filter $\phi(t)=\frac{1}{\sqrt{2\pi\sigma_t^2}}\exp\left(-\frac{t^2}{2\sigma_t^2}\right)$, the invariance scale $T$ corresponds to $\pm 2.5758\sigma_t$, that is, $T=(2\times 2.5758)\,\sigma_t$. You can rewrite the equation to express the time standard deviation in terms of the invariance scale: $\sigma_t=\frac{T}{2\times 2.5758}$.
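As a quick numerical check of the constant used here, the sketch below evaluates the two-sided coverage of ±2.5758 standard deviations with the base MATLAB erf function, and then converts an invariance scale to a standard deviation. The variable names and the example value of T are chosen for illustration.

```matlab
% Coverage of +/- z standard deviations for a normal distribution, and the implied sigma_t.
z = 2.5758;                     % multiplier used in the invariance scale definition
coverage = erf(z/sqrt(2));      % P(|Z| <= z) for a standard normal variable
T = 1e4;                        % example invariance scale in samples (assumed)
sigmaT = T/(2*z);               % time standard deviation implied by T
fprintf("coverage = %.4f, sigma_t = %.1f samples\n",coverage,sigmaT)
```

The coverage evaluates to about 0.99, because 2.5758 is the 99.5th percentile of the standard normal distribution and the interval excludes both tails.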

In JTFS, consider the Gaussian lowpass filter $\phi(t)$ as expressed in the Fourier domain: $\hat{\phi}(\xi)=\exp\left(-\frac{\xi^2}{2\sigma_\xi^2}\right)$, where $\sigma_\xi^2=\frac{1}{4\pi^2\sigma_t^2}$, which simplifies to $\sigma_\xi=\frac{1}{2\pi\sigma_t}$. The factor of $2\pi$ appears because JTFS uses cyclic frequency. $\hat{\phi}(\xi)$ is constructed on the frequency grid $(-1/2,1/2)$. Because the JTFS invariance scale $\tau$ is greater than or equal to 1, JTFS defines the frequency standard deviation as $\sigma_\xi=0.13/\tau$. For the purposes of this example, 0.13 is called the base sigma.

Suppose $\tau=1$. Then $\hat{\phi}(\xi)=\exp\left(-\frac{\xi^2}{2(0.13)^2}\right)$. Because 1/0.13 is approximately 7.69, the interval $(-1/2,1/2)$ contains almost 8 standard deviations, which ensures that $\hat{\phi}(\xi)$ has effectively decayed to zero at the edges of the grid. You can choose a smaller base sigma, but there would be negligible practical benefit. As $\tau$ gets larger, the number of standard deviations in the interval grows, meaning the tails go to 0 faster. The Gaussian filter becomes narrower in the Fourier domain, or equivalently, broader in the time domain. Because $\sigma_\xi=\frac{1}{2\pi\sigma_t}$ and $\sigma_\xi=0.13/\tau$, you can express the time standard deviation in terms of the JTFS invariance scale: $\sigma_t=\frac{1}{2\pi\sigma_\xi}=\frac{1}{2\pi(0.13/\tau)}=\frac{\tau}{2\pi\times 0.13}$.
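The relationship between the frequency standard deviation and the time standard deviation can be checked numerically. The sketch below builds a Gaussian on an FFT frequency grid with a standard deviation of 0.13/τ, inverts it, and measures the spread of the result in time. It is an assumption-laden illustration in base MATLAB, with an arbitrary grid size and τ, and it does not claim to reproduce how timeFrequencyScattering constructs its filter internally.

```matlab
% Gaussian lowpass defined in the Fourier domain and its measured time-domain spread.
N = 2^14;                                  % grid length (assumed)
tau = 200;                                 % example invariance scale (assumed)
basesigma = 0.13;                          % base sigma from the text
xi = [0:N/2-1, -N/2:-1]'/N;                % cyclic frequency grid in cycles/sample
sigmaXi = basesigma/tau;                   % frequency standard deviation
phihat = exp(-xi.^2/(2*sigmaXi^2));        % Gaussian lowpass in the Fourier domain

phit = fftshift(real(ifft(phihat)));       % time-domain filter, centered at n = 0
n = (-N/2:N/2-1)';                         % centered sample index
sigmaEst = sqrt(sum(n.^2.*phit)/sum(phit));   % empirical time standard deviation
sigmaPred = tau/(2*pi*basesigma);          % value predicted by the derivation
```

The measured spread sigmaEst agrees closely with sigmaPred, which supports the expression for the time standard deviation in terms of τ.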

So there are two equivalent expressions for the time standard deviation:

  • Time Scattering: $\sigma_t^{(TS)}=\frac{T}{2\times 2.5758}$

  • JTFS: $\sigma_t^{(JTFS)}=\frac{\tau}{2\pi\times 0.13}$

By equating the two expressions:

$\frac{T}{2\times 2.5758}=\frac{\tau}{2\pi\times 0.13}$

you can determine the JTFS invariance scale $\tau$ that is equivalent to a given time-scattering invariance scale $T$: $\tau=\frac{2\pi\times 0.13}{2\times 2.5758}\,T$.

In the time scattering network, the invariance scale T is 10,000 samples. Determine the equivalent time invariance scale τ in JTFS. Round the result to obtain an integer.

basesigma = 0.13;
tau = round((2*pi*basesigma)/(2*2.5758)*invscale)
tau = 
1586
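Rounding τ to an integer introduces a small mismatch between the two time standard deviations. The following sketch, which uses only base MATLAB and variable names chosen for illustration, quantifies that mismatch for the values in this example.

```matlab
% Effect of rounding tau on the equivalence of the two time standard deviations.
invscale = 1e4;                                 % time scattering invariance scale T, in samples
basesigma = 0.13;                               % JTFS base sigma
z = 2.5758;                                     % normal-distribution multiplier from the derivation
tauExact = (2*pi*basesigma)/(2*z)*invscale;     % unrounded JTFS invariance scale
tau = round(tauExact);                          % integer scale passed to the network
sigmaTS = invscale/(2*z);                       % sigma_t for time scattering
sigmaJTFS = tau/(2*pi*basesigma);               % sigma_t for JTFS with the rounded tau
relErr = abs(sigmaJTFS-sigmaTS)/sigmaTS;        % relative mismatch caused by rounding
```

The relative mismatch is well below 0.1 percent, so the rounded invariance scale yields lowpass filters that are close to identical.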

Create a JTFS network with a signal length of 2^16. Set the time invariance scale to tau. Specify single precision and a base-2 time maximum padding factor of 0. (For more information about the padding factor, see filterpadding.) Use the JTFS filterbank object function to obtain the time lowpass filter.

jtfsn = timeFrequencyScattering(SignalLength=slen, ...
    TimeInvarianceScale=tau, ...
    FilterDataType="single", ...
    TimeMaxPaddingFactor=0);
[~,~,phifJTFS] = filterbank(jtfsn);

Plot both lowpass filters in the time domain. Confirm the filters are equivalent.

phitTS = ifftshift(ifft(phifTS));
phitJTFS = ifftshift(ifft(phifJTFS));
t = -slen/2:slen/2-1;
plot(t,[phitTS phitJTFS]);
grid on
axis tight
xlabel("Samples")
ylabel("Amplitude")
title(["Time Lowpass Filters for"; ...
    "Time and Joint Time-Frequency Scattering"])
legend("Time Scattering","JTFS")

Figure: Time Lowpass Filters for Time and Joint Time-Frequency Scattering. Amplitude versus Samples, with lines for Time Scattering and JTFS.

References

[1] Andén, Joakim, Vincent Lostanlen, and Stéphane Mallat. “Joint Time–Frequency Scattering.” IEEE Transactions on Signal Processing 67, no. 14 (July 15, 2019): 3704–18. https://doi.org/10.1109/TSP.2019.2918992

[2] Lostanlen, Vincent, Christian El-Hajj, Mathias Rossignol, Grégoire Lafay, Joakim Andén, and Mathieu Lagrange. “Time–Frequency Scattering Accurately Models Auditory Similarities between Instrumental Playing Techniques.” EURASIP Journal on Audio, Speech, and Music Processing 2021, no. 1 (December 2021): 3. https://doi.org/10.1186/s13636-020-00187-z

[3] Mallat, Stéphane. “Group Invariant Scattering.” Communications on Pure and Applied Mathematics 65, no. 10 (October 2012): 1331–98. https://doi.org/10.1002/cpa.21413
