メインコンテンツ

深層学習ネットワークを使用したカクテル パーティ ソースの分離

この例では、深層学習ネットワークを使用して音声信号を分離する方法を説明します。

はじめに

カクテル パーティ効果とは、脳が他の声やバックグラウンド ノイズを除去して 1 人の話者に集中できる能力のことをいいます。人間は、カクテル パーティ問題を非常にうまく処理します。この例では、1 人の男性と 1 人の女性が同時に発話している混合音声に対し、深層学習ネットワークを使用して個々の話者を分離する方法を説明します。

必要なファイルのダウンロード

例の詳細に入る前に、事前学習済みのネットワークと 4 つのオーディオ ファイルをダウンロードします。

downloadFolder = matlab.internal.examples.downloadSupportFile("audio/examples","cocktailpartyfc.zip");
dataFolder = tempdir;
dataset = fullfile(dataFolder,"CocktailPartySourceSeparation");
unzip(downloadFolder,dataset)

問題の概要

4 kHz でサンプリングされた男性と女性の音声を含むオーディオ ファイルを読み込みます。参考のために、オーディオ ファイルを別々に再生します。

[mSpeech,Fs] = audioread(fullfile(dataset,"MaleSpeech-16-4-mono-20secs.wav"));
sound(mSpeech,Fs)
[fSpeech] = audioread(fullfile(dataset,"FemaleSpeech-16-4-mono-20secs.wav"));
sound(fSpeech,Fs)

2 つの音声ソースを混合します。混合音声において各ソースが必ず同じ強度になるようにします。最大振幅が 1 となるように混合音声をスケーリングします。

mSpeech = mSpeech/norm(mSpeech);
fSpeech = fSpeech/norm(fSpeech);

ampAdj = max(abs([mSpeech;fSpeech]));
mSpeech = mSpeech/ampAdj;
fSpeech = fSpeech/ampAdj;

mix = mSpeech + fSpeech;
mix = mix./max(abs(mix));

元の信号と混合信号を可視化します。混合音声信号を再生します。この例では、混合音声から男性のソースと女性のソースを抽出するソース分離スキームを説明します。

t = (0:numel(mix)-1)*(1/Fs);

figure
tiledlayout(3,1)

nexttile
plot(t,mSpeech)
title("Male Speech")
grid on

nexttile
plot(t,fSpeech)
title("Female Speech")
grid on

nexttile
plot(t,mix)
title("Speech Mix")
xlabel("Time (s)")
grid on

Figure contains 3 axes objects. Axes object 1 with title Male Speech contains an object of type line. Axes object 2 with title Female Speech contains an object of type line. Axes object 3 with title Speech Mix, xlabel Time (s) contains an object of type line.

混合オーディオを再生します。

sound(mix,Fs)

時間-周波数表現

stft を使用して、男性音声信号、女性音声信号、および混合音声信号の時間-周波数 (TF) 表現を可視化します。長さ 128 のハン ウィンドウ、長さ 128 の FFT、および長さ 96 のオーバーラップを使用します。

windowLength = 128;
fftLength = 128;
overlapLength = 96;
win = hann(windowLength,"periodic");

figure
tiledlayout(3,1)

nexttile
stft(mSpeech,Fs,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
title("Male Speech")

nexttile
stft(fSpeech,Fs,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
title("Female Speech")

nexttile
stft(mix,Fs,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
title("Mix Speech")

Figure contains 3 axes objects. Axes object 1 with title Male Speech, xlabel Time (s), ylabel Frequency (kHz) contains an object of type image. Axes object 2 with title Female Speech, xlabel Time (s), ylabel Frequency (kHz) contains an object of type image. Axes object 3 with title Mix Speech, xlabel Time (s), ylabel Frequency (kHz) contains an object of type image.

理想的な時間-周波数マスクを使用したソース分離

TF マスクの適用は、競合する音声から目的のオーディオ信号を分離する効果的な手法であることが示されています。TF マスクは、基となる STFT と同じサイズの行列です。基となる STFT とこのマスクが要素ごとに乗算されて、目的のソースが分離されます。TF マスクは、バイナリ マスクまたはソフト マスクのいずれかです。

理想的なバイナリ マスクを使用したソースの分離

理想的なバイナリ マスクでは、マスクのセルの値が 0 または 1 のいずれかです。特定の TF セルにおいて、目的のソースのパワーが他のソースを組み合わせたパワーより大きい場合、そのセルは 1 に設定されます。そうでない場合、セルは 0 に設定されます。

男性話者の理想的なバイナリ マスクを計算して可視化します。

P_M = stft(mSpeech,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
P_F = stft(fSpeech,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
[P_mix,F] = stft(mix,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");

binaryMask = abs(P_M) >= abs(P_F);

figure
plotMask(binaryMask,windowLength - overlapLength,F,Fs)

Figure contains an axes object. The axes object with xlabel Time (s), ylabel Frequency (Hz) contains an object of type image.

混合音声の STFT と男性話者のバイナリ マスクを乗算し、男性音声の STFT を推定します。混合音声の STFT と男性話者のバイナリ マスクの逆数を乗算し、女性音声の STFT を推定します。

P_M_Hard = P_mix.*binaryMask;
P_F_Hard = P_mix.*(1-binaryMask);

逆短時間 FFT (ISTFT) を使用して、男性と女性のオーディオ信号を推定します。推定された信号と元の信号を可視化します。推定された男性と女性の音声信号を再生します。

mSpeech_Hard = istft(P_M_Hard,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
fSpeech_Hard = istft(P_F_Hard,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");

figure
tiledlayout(2,2)

nexttile
plot(t,mSpeech)
axis([t(1) t(end) -1 1])
title("Original Male Speech")
grid on

nexttile
plot(t,mSpeech_Hard)
axis([t(1) t(end) -1 1])
xlabel("Time (s)")
title("Estimated Male Speech")
grid on

nexttile
plot(t,fSpeech)
axis([t(1) t(end) -1 1])
title("Original Female Speech")
grid on

nexttile
plot(t,fSpeech_Hard)
axis([t(1) t(end) -1 1])
title("Estimated Female Speech")
xlabel("Time (s)")
grid on

Figure contains 4 axes objects. Axes object 1 with title Original Male Speech contains an object of type line. Axes object 2 with title Estimated Male Speech, xlabel Time (s) contains an object of type line. Axes object 3 with title Original Female Speech contains an object of type line. Axes object 4 with title Estimated Female Speech, xlabel Time (s) contains an object of type line.

sound(mSpeech_Hard,Fs)
sound(fSpeech_Hard,Fs)

理想的なソフト マスクを使用したソースの分離

ソフト マスクでは、TF マスクのセルの値が、混合音声の合計強度に対する目的のソースの強度の比に等しくなります。TF セルは、[0,1] の範囲の値をもちます。

男性話者のソフト マスクを計算します。混合音声の STFT と男性話者のソフト マスクを乗算し、男性話者の STFT を推定します。混合音声の STFT と女性話者のソフト マスクを乗算し、女性話者の STFT を推定します。

ISTFT を使用して、男性と女性のオーディオ信号を推定します。

softMask = abs(P_M)./(abs(P_F) + abs(P_M) + eps);

P_M_Soft = P_mix.*softMask;
P_F_Soft = P_mix.*(1-softMask);

mSpeech_Soft = istft(P_M_Soft,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
fSpeech_Soft = istft(P_F_Soft,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");

推定された信号と元の信号を可視化します。推定された男性と女性の音声信号を再生します。ここで、非常に良好な結果が得られるのは、分離した男性と女性の信号に関する十分な知識に基づいてマスクを作成したためであることに注意してください。

figure
tiledlayout(2,2)

nexttile
plot(t,mSpeech)
axis([t(1) t(end) -1 1])
title("Original Male Speech")
grid on

nexttile
plot(t,mSpeech_Soft)
axis([t(1) t(end) -1 1])
title("Estimated Male Speech")
grid on

nexttile
plot(t,fSpeech)
axis([t(1) t(end) -1 1])
xlabel("Time (s)")
title("Original Female Speech")
grid on

nexttile
plot(t,fSpeech_Soft)
axis([t(1) t(end) -1 1])
xlabel("Time (s)")
title("Estimated Female Speech")
grid on

Figure contains 4 axes objects. Axes object 1 with title Original Male Speech contains an object of type line. Axes object 2 with title Estimated Male Speech contains an object of type line. Axes object 3 with title Original Female Speech, xlabel Time (s) contains an object of type line. Axes object 4 with title Estimated Female Speech, xlabel Time (s) contains an object of type line.

sound(mSpeech_Soft,Fs)
sound(fSpeech_Soft,Fs)

深層学習を使用したマスクの推定

この例における深層学習ネットワークの目標は、前述の理想的なソフト マスクを推定することです。ネットワークは男性話者に対応するマスクを推定します。女性話者のマスクは男性マスクから直接導出されます。

標準の深層学習の学習スキームを以下に示します。予測子は、混合 (男性 + 女性) のオーディオの振幅スペクトルです。ターゲットは、男性話者に対応する理想的なソフト マスクです。回帰ネットワークは、予測子入力を使用して、出力と入力のターゲット間の平均二乗誤差を最小化します。出力において、出力の振幅スペクトルと混合信号の位相を使用してオーディオ STFT が時間領域に戻されます。

短時間フーリエ変換 (STFT) を使用して、ウィンドウの長さが 128 サンプル、オーバーラップが 127、ハン ウィンドウをもつ周波数領域にオーディオを変換します。負の周波数に対応する周波数サンプルを落として、スペクトル ベクトルのサイズを 65 まで縮小します (時間領域の音声信号が実数のため、これは情報の損失にはつながりません)。予測子の入力は、連続する 20 個の STFT ベクトルで構成されます。出力は 65 x 20 のソフト マスクです。

学習済みのネットワークを使用して、男性音声を推定します。学習済みのネットワークへの入力は、混合 (男性 + 女性) の音声オーディオです。

STFT のターゲットと予測子

この節では、学習データセットからターゲット信号と予測子信号を生成する方法を説明します。

それぞれ 4 kHz でサンプリングされた、男性話者と女性話者の約 400 秒の音声から成る学習信号を読み取ります。学習の速度を上げるため、低いサンプル レートを使用します。学習信号をトリミングして同じ長さにします。

mSpeechTrain = audioread(fullfile(dataset,"MaleSpeech-16-4-mono-405secs.wav"));
fSpeechTrain = audioread(fullfile(dataset,"FemaleSpeech-16-4-mono-405secs.wav"));

L = min(length(mSpeechTrain),length(fSpeechTrain));  
mSpeechTrain = mSpeechTrain(1:L);
fSpeechTrain = fSpeechTrain(1:L);

それぞれ 4 kHz でサンプリングされた、男性話者と女性話者の約 20 秒の音声から成る検証信号を読み取ります。同じ長さになるように、検証信号をトリミングします。

mSpeechValidate = audioread(fullfile(dataset,"MaleSpeech-16-4-mono-20secs.wav"));
fSpeechValidate = audioread(fullfile(dataset,"FemaleSpeech-16-4-mono-20secs.wav"));

L = min(length(mSpeechValidate),length(fSpeechValidate));  
mSpeechValidate = mSpeechValidate(1:L);
fSpeechValidate = fSpeechValidate(1:L);

学習信号をスケーリングして同じ強度にします。検証信号をスケーリングして同じ強度にします。

mSpeechTrain = mSpeechTrain/norm(mSpeechTrain);
fSpeechTrain = fSpeechTrain/norm(fSpeechTrain);
ampAdj = max(abs([mSpeechTrain;fSpeechTrain]));

mSpeechTrain = mSpeechTrain/ampAdj;
fSpeechTrain = fSpeechTrain/ampAdj;

mSpeechValidate = mSpeechValidate/norm(mSpeechValidate);
fSpeechValidate = fSpeechValidate/norm(fSpeechValidate);
ampAdj = max(abs([mSpeechValidate;fSpeechValidate]));

mSpeechValidate = mSpeechValidate/ampAdj;
fSpeechValidate = fSpeechValidate/ampAdj;

学習用と検証用の "カクテル パーティ" 混合音声を作成します。

mixTrain = mSpeechTrain + fSpeechTrain;
mixTrain = mixTrain/max(mixTrain);

mixValidate = mSpeechValidate + fSpeechValidate;
mixValidate = mixValidate/max(mixValidate);

学習用の STFT を生成します。

windowLength = 128;
fftLength = 128;
overlapLength = 128-1;
Fs = 4000;
win = hann(windowLength,"periodic");

P_mix0 = abs(stft(mixTrain,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided"));
P_M = abs(stft(mSpeechTrain,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided"));
P_F = abs(stft(fSpeechTrain,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided"));

混合音声の STFT の対数を取ります。平均と標準偏差を使用して、値を正規化します。

P_mix = log(P_mix0 + eps);
MP = mean(P_mix(:));
SP = std(P_mix(:));
P_mix = (P_mix - MP)/SP;

検証用の STFT を生成します。混合音声の STFT の対数を取ります。平均と標準偏差を使用して、値を正規化します。

P_Val_mix0 = stft(mixValidate,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided");
P_Val_M = abs(stft(mSpeechValidate,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided"));
P_Val_F = abs(stft(fSpeechValidate,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided"));

P_Val_mix = log(abs(P_Val_mix0) + eps);
MP = mean(P_Val_mix(:));
SP = std(P_Val_mix(:));
P_Val_mix = (P_Val_mix - MP) / SP;

ネットワークへの入力が適度に滑らかに分布していて正規化されている場合、ニューラル ネットワークの学習は最も簡単になります。データ分布が滑らかであることを確認するために、学習データの STFT の値のヒストグラムをプロットします。

figure
histogram(P_mix,EdgeColor="none",Normalization="pdf")
xlabel("Input Value")
ylabel("Probability Density")

Figure contains an axes object. The axes object with xlabel Input Value, ylabel Probability Density contains an object of type histogram.

学習用のソフト マスクを計算します。ネットワークの学習では、このマスクをターゲット信号として使用します。

maskTrain = P_M./(P_M + P_F + eps);

検証用のソフト マスクを計算します。このマスクを使用して、学習済みのネットワークによって出力されるマスクを評価します。

maskValidate = P_Val_M./(P_Val_M + P_Val_F + eps);

ターゲットのデータ分布が滑らかであることを確認するために、学習データのマスクの値のヒストグラムをプロットします。

figure

histogram(maskTrain,EdgeColor="none",Normalization="pdf")
xlabel("Input Value")
ylabel("Probability Density")

Figure contains an axes object. The axes object with xlabel Input Value, ylabel Probability Density contains an object of type histogram.

予測子とターゲットの信号から、サイズが (65, 20) のチャンクを作成します。さらに多くの学習サンプルを取得するために、連続するチャンクを 10 セグメント分オーバーラップさせます。

seqLen = 20;
seqOverlap = 10;
mixSequences = zeros(1 + fftLength/2,seqLen,1,0);
maskSequences = zeros(1 + fftLength/2,seqLen,1,0);

loc = 1;
while loc < size(P_mix,2) - seqLen
    mixSequences(:,:,:,end+1) = P_mix(:,loc:loc+seqLen-1);
    maskSequences(:,:,:,end+1) = maskTrain(:,loc:loc+seqLen-1);
    loc = loc + seqOverlap;
end

評価用の予測子とターゲットの信号から、サイズが (65,20) のチャンクを作成します。

mixValSequences = zeros(1 + fftLength/2,seqLen,1,0);
maskValSequences = zeros(1 + fftLength/2,seqLen,1,0);
seqOverlap = seqLen;

loc = 1;
while loc < size(P_Val_mix,2) - seqLen
    mixValSequences(:,:,:,end+1) = P_Val_mix(:,loc:loc+seqLen-1);
    maskValSequences(:,:,:,end+1) = maskValidate(:,loc:loc+seqLen-1);
    loc = loc + seqOverlap;
end

学習信号と検証信号の形状を変更します。

mixSequencesT = reshape(mixSequences,[1 1 (1 + fftLength/2)*seqLen size(mixSequences,4)]);
mixSequencesV = reshape(mixValSequences,[1 1 (1 + fftLength/2)*seqLen size(mixValSequences,4)]);
maskSequencesT = reshape(maskSequences,[1 1 (1 + fftLength/2)*seqLen size(maskSequences,4)]);
maskSequencesV = reshape(maskValSequences,[1 1 (1 + fftLength/2)*seqLen size(maskValSequences,4)]);

深層学習ネットワークの定義

ネットワークの層を定義します。サイズが 1 x 1 x 1300 のイメージになるように入力サイズを指定します。2 つの隠れ全結合層を、それぞれ 1300 ニューロンで定義します。それぞれの隠れ全結合層の後にバイアス付きシグモイド層を続けます。バッチ正規化層は、出力の平均と標準偏差を正規化します。1300 ニューロンをもつ全結合層を追加します。

numNodes = (1 + fftLength/2)*seqLen;

layers = [ ...
    
    imageInputLayer([1 1 (1 + fftLength/2)*seqLen],Normalization="None")
    
    fullyConnectedLayer(numNodes)
    BiasedSigmoidLayer(6)
    batchNormalizationLayer
    dropoutLayer(0.1)

    fullyConnectedLayer(numNodes)
    BiasedSigmoidLayer(6)
    batchNormalizationLayer
    dropoutLayer(0.1)

    fullyConnectedLayer(numNodes)
    BiasedSigmoidLayer(0)
    
    ];

ネットワークの学習オプションを指定します。ネットワークが学習データから 3 個のパスを作るよう MaxEpochs3 に設定します。ネットワークが一度に 64 の学習信号を確認するよう MiniBatchSize64 に設定します。Plotstraining-progress に設定し、反復回数の増大に応じた学習の進行状況を示すプロットを生成します。Verbosefalse に設定し、プロットで示されるデータに対応する表出力のコマンド ライン ウィンドウへの表示を無効にします。Shuffleevery-epoch に設定し、各エポックの最初に学習シーケンスをシャッフルします。LearnRateSchedulepiecewise に設定し、特定のエポック数 (1) が経過するたびに、指定された係数 (0.1) によって学習率を減らします。ValidationData を検証予測子とターゲットに設定します。検証の平均二乗誤差がエポックごとに 1 回計算されるように ValidationFrequency を設定します。この例では、適応モーメント推定 (ADAM) ソルバーを使用します。

maxEpochs = 3;
miniBatchSize = 64;

options = trainingOptions("adam", ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    SequenceLength="longest", ...
    Shuffle="every-epoch", ...
    Verbose=0, ...
    Plots="training-progress", ...
    ValidationFrequency=floor(size(mixSequencesT,4)/miniBatchSize), ...
    ValidationData={mixSequencesV,permute(maskSequencesV,[4 3 1 2])}, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.9, ...
    LearnRateDropPeriod=1);

深層学習ネットワークの学習

trainnet を使用し、指定した学習オプションと層のアーキテクチャでネットワークに学習させます。学習セットが大きいため、学習プロセスには数分かかる場合があります。事前学習済みのネットワークを読み込むため、speedupExampletrue に設定します。

speedupExample = false;
if ~speedupExample
    lossFcn = @(Y,T)0.5*l2loss(Y,T,NormalizationFactor="batch-size");
    CocktailPartyNet = trainnet(mixSequencesT,permute(maskSequencesT,[4 3 1 2]),layers,lossFcn,options);
else
    s = load(fullfile(dataset,"CocktailPartyNet.mat"));
    CocktailPartyNet = s.CocktailPartyNet;
end

検証予測子をネットワークに渡します。出力は推定されたマスクです。推定されたマスクの形状を変更します。

estimatedMasks0 = predict(CocktailPartyNet,mixSequencesV);

estimatedMasks0 = estimatedMasks0.';
estimatedMasks0 = reshape(estimatedMasks0,1 + fftLength/2,numel(estimatedMasks0)/(1 + fftLength/2));

深層学習ネットワークの評価

実際のマスクと期待されるマスクとの誤差のヒストグラムをプロットします。

figure
histogram(maskValSequences(:) - estimatedMasks0(:),EdgeColor="none",Normalization="pdf")
xlabel("Mask Error")
ylabel("Probability Density")

Figure contains an axes object. The axes object with xlabel Mask Error, ylabel Probability Density contains an object of type histogram.

ソフト マスクの推定の評価

男性と女性のソフト マスクを推定します。ソフト マスクをしきい値処理し、男性と女性のバイナリ マスクを推定します。

SoftMaleMask = estimatedMasks0; 
SoftFemaleMask = 1 - SoftMaleMask;

推定されたマスクのサイズに一致するよう混合音声の STFT を短縮します。

P_Val_mix0 = P_Val_mix0(:,1:size(SoftMaleMask,2));

混合音声の STFT と男性のソフト マスクを乗算し、男性音声の推定 STFT を取得します。

P_Male = P_Val_mix0.*SoftMaleMask;

ISTFT を使用して、推定された男性オーディオ信号を取得します。オーディオをスケーリングします。

maleSpeech_est_soft = istft(P_Male,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided",ConjugateSymmetric=true);
maleSpeech_est_soft = maleSpeech_est_soft/max(abs(maleSpeech_est_soft));

分析範囲および関連する時間ベクトルを決定します。

range = windowLength:numel(maleSpeech_est_soft)-windowLength;
t = range*(1/Fs);

推定された男性信号と元の男性信号を可視化します。ソフト マスクで推定された男性音声を再生します。

sound(maleSpeech_est_soft(range),Fs)

figure
tiledlayout(2,1)

nexttile
plot(t,mSpeechValidate(range))
title("Original Male Speech")
xlabel("Time (s)")
grid on

nexttile
plot(t,maleSpeech_est_soft(range))
xlabel("Time (s)")
title("Estimated Male Speech (Soft Mask)")
grid on

Figure contains 2 axes objects. Axes object 1 with title Original Male Speech, xlabel Time (s) contains an object of type line. Axes object 2 with title Estimated Male Speech (Soft Mask), xlabel Time (s) contains an object of type line.

混合音声の STFT と女性のソフト マスクを乗算し、女性音声の推定 STFT を取得します。ISTFT を使用して、推定された男性オーディオ信号を取得します。オーディオをスケーリングします。

P_Female = P_Val_mix0.*SoftFemaleMask;

femaleSpeech_est_soft = istft(P_Female,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided",ConjugateSymmetric=true);
femaleSpeech_est_soft = femaleSpeech_est_soft/max(femaleSpeech_est_soft);

推定された女性信号と元の女性信号を可視化します。推定された女性音声を再生します。

sound(femaleSpeech_est_soft(range),Fs)

figure
tiledlayout(2,1)

nexttile
plot(t,fSpeechValidate(range))
title("Original Female Speech")
grid on

nexttile
plot(t,femaleSpeech_est_soft(range))
xlabel("Time (s)")
title("Estimated Female Speech (Soft Mask)")
grid on

Figure contains 2 axes objects. Axes object 1 with title Original Female Speech contains an object of type line. Axes object 2 with title Estimated Female Speech (Soft Mask), xlabel Time (s) contains an object of type line.

バイナリ マスク推定結果の評価

ソフト マスクをしきい値処理し、男性と女性のバイナリ マスクを推定します。

HardMaleMask = SoftMaleMask >= 0.5;
HardFemaleMask = SoftMaleMask < 0.5;

混合音声の STFT と男性のバイナリ マスクを乗算し、男性音声の推定 STFT を取得します。ISTFT を使用して、推定された男性オーディオ信号を取得します。オーディオをスケーリングします。

P_Male = P_Val_mix0.*HardMaleMask;

maleSpeech_est_hard = istft(P_Male,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided",ConjugateSymmetric=true);
maleSpeech_est_hard = maleSpeech_est_hard/max(maleSpeech_est_hard);

推定された男性信号と元の男性信号を可視化します。バイナリ マスクで推定された男性音声を再生します。

sound(maleSpeech_est_hard(range),Fs)

figure
tiledlayout(2,1)

nexttile
plot(t,mSpeechValidate(range))
title("Original Male Speech")
grid on

nexttile
plot(t,maleSpeech_est_hard(range))
xlabel("Time (s)")
title("Estimated Male Speech (Binary Mask)")
grid on

Figure contains 2 axes objects. Axes object 1 with title Original Male Speech contains an object of type line. Axes object 2 with title Estimated Male Speech (Binary Mask), xlabel Time (s) contains an object of type line.

混合音声の STFT と女性のバイナリ マスクを乗算し、男性音声の推定 STFT を取得します。ISTFT を使用して、推定された男性オーディオ信号を取得します。オーディオをスケーリングします。

P_Female = P_Val_mix0.*HardFemaleMask;

femaleSpeech_est_hard = istft(P_Female,Window=win,OverlapLength=overlapLength,FFTLength=fftLength,FrequencyRange="onesided",ConjugateSymmetric=true);
femaleSpeech_est_hard = femaleSpeech_est_hard/max(femaleSpeech_est_hard);

推定された女性音声信号と元の女性音声信号を可視化します。推定された女性音声を再生します。

sound(femaleSpeech_est_hard(range),Fs)

figure
tiledlayout(2,1)

nexttile
plot(t,fSpeechValidate(range))
title("Original Female Speech")
grid on

nexttile
plot(t,femaleSpeech_est_hard(range))
title("Estimated Female Speech (Binary Mask)")
grid on

Figure contains 2 axes objects. Axes object 1 with title Original Female Speech contains an object of type line. Axes object 2 with title Estimated Female Speech (Binary Mask) contains an object of type line.

混合音声、元の女性音声と男性音声、推定された女性音声と男性音声について、それぞれ 1 秒間のセグメントの STFT を比較します。

range = 7e4:7.4e4;

figure
stft(mixValidate(range),Fs,Window=win,OverlapLength=64,FFTLength=fftLength,FrequencyRange="onesided");
title("Mix STFT")

Figure contains an axes object. The axes object with title Mix STFT, xlabel Time (ms), ylabel Frequency (kHz) contains an object of type image.

figure
tiledlayout(3,1)

nexttile
stft(mSpeechValidate(range),Fs,Window=win,OverlapLength=64,FFTLength=fftLength,FrequencyRange="onesided");
title("Male STFT (Actual)")

nexttile
stft(maleSpeech_est_soft(range),Fs,Window=win,OverlapLength=64,FFTLength=fftLength,FrequencyRange="onesided");
title("Male STFT (Estimated - Soft Mask)")

nexttile
stft(maleSpeech_est_hard(range),Fs,Window=win,OverlapLength=64,FFTLength=fftLength,FrequencyRange="onesided");
title("Male STFT (Estimated - Binary Mask)");

Figure contains 3 axes objects. Axes object 1 with title Male STFT (Actual), xlabel Time (ms), ylabel Frequency (kHz) contains an object of type image. Axes object 2 with title Male STFT (Estimated - Soft Mask), xlabel Time (ms), ylabel Frequency (kHz) contains an object of type image. Axes object 3 with title Male STFT (Estimated - Binary Mask), xlabel Time (ms), ylabel Frequency (kHz) contains an object of type image.

figure
tiledlayout(3,1)

nexttile
stft(fSpeechValidate(range),Fs,Window=win,OverlapLength=64,FFTLength=fftLength,FrequencyRange="onesided");
title("Female STFT (Actual)")

nexttile
stft(femaleSpeech_est_soft(range),Fs,Window=win,OverlapLength=64,FFTLength=fftLength,FrequencyRange="onesided");
title("Female STFT (Estimated - Soft Mask)")

nexttile
stft(femaleSpeech_est_hard(range),Fs,Window=win,OverlapLength=64,FFTLength=fftLength,FrequencyRange="onesided");
title("Female STFT (Estimated - Binary Mask)")

Figure contains 3 axes objects. Axes object 1 with title Female STFT (Actual), xlabel Time (ms), ylabel Frequency (kHz) contains an object of type image. Axes object 2 with title Female STFT (Estimated - Soft Mask), xlabel Time (ms), ylabel Frequency (kHz) contains an object of type image. Axes object 3 with title Female STFT (Estimated - Binary Mask), xlabel Time (ms), ylabel Frequency (kHz) contains an object of type image.

参考文献

[1] "Probabilistic Binary-Mask Cocktail-Party Source Separation in a Convolutional Deep Neural Network", Andrew J.R. Simpson, 2015.

参考

| |

トピック