メインコンテンツ

エンドツーエンドの話者分離モデルの学習

話者分離は、困難な音声処理タスクではありますが重要です。ダウンストリーム アプリケーションには、音声認識、話者ダイアライゼーション、話者識別、音声再生強化などがあります。この例では、話者に依存しない音声分離のために、エンドツーエンドの深層学習ネットワークに学習させます。

事前学習済みの話者分離モデルを使用するには、separateSpeakers (Audio Toolbox)を参照してください。

さまざまなモデルのパフォーマンスの比較については、Compare Speaker Separation Models (Audio Toolbox)を参照してください。

事前学習済みモデルを使用した話者分離の実行

Conv-TasNet アーキテクチャ [1] は、エンドツーエンドの音声分離における初期段階の重要な発展でした。話者分離メトリクスではトランスフォーマーベースのモデルに追い抜かれているものの、その速度や小さなサイズ、学習の容易さから、依然として重要なモデルとなっています。これらの特性により、Conv-TasNet は特定のノイズ条件下での展開を想定した話者分離ソリューションの候補としても適しています。

事前学習済みの Conv-TasNet モデルやトランスフォーマーベースのモデルにアクセスするには、separateSpeakers 関数を使用します。通常、話者分離は音声信号を一定数の話者に分離するタスクとして扱われます。一方で、現実のシナリオでは、話者の人数が分からないことが多く、変化することもあります。separateSpeakers が使用する Conv-TasNet モデルは、常に最も支配的な話者を分離し、1 人の話者の信号と "それ以外" (残差とも呼ばれる) を出力することを試みる、one-and-rest 目的関数で学習させています [2]。複数の話者を分離するには、残差で話者が検出されなくなるまで、残差を再帰的にモデルにフィードバックできます。

サンプル オーディオ ファイルを読み取り、チャネルをミキシングして結果を再生します。

[targets,fs] = audioread('MultipleSpeakers-16-8-4channel-5secs.flac');
x = mean(targets,2);
sound(x,fs),pause(size(x,1)/fs + 1)

既定の one-and-rest Conv-TasNet モデルを使用して話者を分離します。結果を可視化するには、出力なしで関数をもう一度呼び出します。結果も再生します。

[y,r] = separateSpeakers(x,fs);
separateSpeakers(x,fs)

for ii = 1:size(y,2)
    sound(y(:,ii),fs),pause(size(x,1)/fs + 1)
end
sound(r,fs)

順列不変スケール不変 S/N 比 (SI-SNR) を評価するには、permutationInvariantSISNR 関数を使用します。順列不変 SI-SNR メトリクスは、話者分離の最も一般的な評価メトリクスおよび学習メトリクスです。

permutationInvariantSISNR(y,targets)
ans = single

6.1334

この例の既定のパラメーターによって発話レベル順列不変学習 (uPIT) パイプラインが定義されており、separateSpeakers を使ってアクセス可能な事前学習済みのモデルと同様の、2 者音声分離モデルに学習させることができます。この例にはパラメーターが用意されており、これを変更することで、one-and-rest 順列不変学習 (OR-PIT) を指定して音声を 1 人とそれ以外 (one-and-rest) に分離したり、システムにノイズを含めたり除外したりすることができます。

削減されたデータ セットで実行して例を高速化するには、speedupExampletrue に設定します。公開されている例をすべて実行するには、speedupExamplefalse に設定します。

speedupExample = true;

データのダウンロード

学習データと検証データ

例を高速化する場合、学習と開発の両方のために、LibriSpeech の dev-clean セットがダウンロードされ、解凍されます。例を高速化しない場合、例では LibriSpeech の train-clean-100 サブセットが使用され、検証には dev-clean セットのみが使用されます。

downloadDatasetFolder = tempdir;

trainDatasetFolder = {};
devDatasetFolder = {};

if speedupExample
    trainDatasetFolder{end+1} = char(downloadAndUnzip("dev-clean",downloadDatasetFolder));
else
    devDatasetFolder{end+1} = char(downloadAndUnzip("dev-clean",downloadDatasetFolder));
    trainDatasetFolder{end+1} = char(downloadAndUnzip("train-clean-100",downloadDatasetFolder));

    % Uncomment following lines to train on entire LibriSpeech dataset
    %trainDatasetFolder{end+1} = char(downloadAndUnzip("train-clean-360",downloadDatasetFolder));
    %trainDatasetFolder{end+1} = char(downloadAndUnzip("train-other-500",downloadDatasetFolder));
end

adsTrain = audioDatastore(trainDatasetFolder,IncludeSubfolders=true,OutputDataType="single");

if speedupExample
    adsTrain = shuffle(adsTrain);
    adsValidation = subset(adsTrain,1:20);
    adsTrain = subset(adsTrain,21:200);
else
    adsValidation = audioDatastore(devDatasetFolder,IncludeSubfolders=true,OutputDataType="single");
end

ノイズ データ

バックグラウンド ノイズ ファイルを含むデータストアを作成します。

noiseDatasetFolder = "WashingMachine-16-8-mono-1000secs.mp3";
adsNoise = audioDatastore(noiseDatasetFolder,IncludeSubfolders=true,OutputDataType="single");

前処理パイプラインの定義

オーディオのリサンプリング

学習データと検証データを 8 kHz にリサンプリングするための変換を定義します。Conv-TasNet はオーディオを直接受け入れるため、メモリ不足の問題を回避するために低いサンプル レートを使用するのが一般的です。

fs = 8e3;
adsTrain = transform(adsTrain,@(x,info)resampleTransform(x,info,fs),IncludeInfo=true);
adsValidation = transform(adsValidation,@(x,info)resampleTransform(x,info,fs),IncludeInfo=true);

動的ミキシングを使用したデータの拡張

DynamicMixer オブジェクトを作成します。例を開くと、DynamicMixer クラス定義が現在のフォルダーに配置されます。ノイズ付加の SNR を指定するには、dB 単位で範囲を指定します。ミキシングの際、オブジェクトはその範囲内の一様分布から選択を行います。ターゲット信号にノイズを含めるには、NoiseInTargets'mix' または 'append' として指定します。NoiseInTargets が 'exclude' に設定されている場合、モデルは話者分離を実行すると同時に入力信号のノイズ除去を試みます。NoiseFiles は、ファイル名の cell 配列または信号の cell 配列に設定できます。オーディオ信号出力の MaxLength を設定します。NumSignals プロパティを、各学習例に含めることができる信号の数の範囲に設定します。uPIT 学習の場合、NumSignals はスカラーでなければなりません。

noiseFiles = readall(adsNoise);
mixer = DynamicMixer( ...
    AugFiles=adsTrain.UnderlyingDatastores{1}.Files, ...
    SignalWeightRange=[0.9 1], ...
    NoiseFiles=noiseFiles, ...
    NoiseWeightRange=[20 60], ... % SNR dB
    NoiseInTargets='exclude', ...
    MaxLength=fs *5, ... % move the slider to specify max length in seconds
    NumSignals=[2 2]);

検証セットからサンプルを読み取り、ミキシングを適用します。

[x,xinfo] = read(adsValidation);
xt = mix(mixer,x,xinfo);
x1 = xt{1};
t1 = xt{2};

予測子とターゲット信号を検査します。

tiledlayout(2,1)

timevec = (0:size(x1,1)-1)/fs;

nexttile()
plot(timevec,x1)
ylabel("Mixed Signal")

nexttile()
plot(timevec,t1)
ylabel("Target Signals")

sound(x1,fs),pause(size(x1,1)/fs + 1)
for ii = 1:size(t1,2)
    sound(t1(:,ii),fs),pause(size(t1,1)/fs + 1)
end

ミニバッチ キューの作成

ミニバッチ内のすべてのオーディオ ファイルの長さを統一し、ターゲットのサイズを均一に保つため、minibatchqueueを使用します。preprocessMiniBatch 関数は、信号をバッチの中央値の長さに合わせてパディングまたは切り捨てた後、バッチを cell 配列として DynamicMixer に渡します。DynamicMixer はバッチごとに NumSignals 範囲内の 1 つの値をランダムに選択し、ミニバッチに追加します。これは、バッチ処理でターゲットのサイズを均一に保つために必要です。

mbqTrain = minibatchqueue(adsTrain,2, ...
    OutputAsDlarray=true, ...
    MiniBatchFormat="TCB", ...
    MiniBatchFcn=@(m)preprocessMiniBatch(m,mixer,fs));

mbqValidation = minibatchqueue(adsValidation,2, ...
    OutputAsDlarray=true, ...
    MiniBatchFormat="TCB", ...
    MiniBatchFcn=@(m)preprocessMiniBatch(m,mixer,fs));

モデルの定義

Conv-TasNet モデルの作成

サポート関数 convtasnet を使用し、[1] に基づいてランダムに初期化された Conv-TasNet モデルを構築します。ここで公開されているハイパーパラメーターの詳細と解析については、[1] を参照してください。モデルを構築するには、この例を開いたときに現在のフォルダーに配置されるいくつかのカスタム層クラス定義が必要です。

net = convtasnet( ...
    NumFiltersAutoEncoder=512, ... N  - Number of filters in autoencoder
    FilterLength=16, ...            L  - Length of filters (in samples)
    NumChannelsBottleneck=128, ... B  - Number of channels in bottleneck
    NumChannelsConvBlocks=512, ... H  - Number of channels in convolutional blocks
    NumChannelsSkipPaths= 128, ... Sc - Number of channels in skip connection paths' 1x1-conv blocks
    KernelSizeConvBlocks=3, ...     P  - Kernel size in convolution blocks
    NumConvBlocksInRepeat=8, ...   X  - Number of convolutional blocks in each repeat
    NumRepeats=3, ...              R  - Number of repeats
    NumResponses=2);

モデルを解析するには、analyzeNetworkを使用します。

analyzeNetwork(net)

学習オプションの定義

学習オプションを定義するには、trainingOptionsを使用します。

mbs = 6;
options = trainingOptions("adam", ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    MaxEpochs=10, ...
    MiniBatchSize=mbs, ...
    TargetDataFormats="TCB", ...
    Verbose=false, ...
    GradientThreshold=5, ...
    InitialLearnRate=0.00015, ...
    OutputNetwork="best-validation", ...
    ExecutionEnvironment="gpu", ...
    ValidationData=mbqValidation, ...
    ValidationFrequency=2*round(numel(adsTrain.UnderlyingDatastores{1}.Files)/mbs), ...
    ValidationPatience=3);

モデルの学習

モデルに学習させるには、trainnetを使用します。permutationInvariantSISNR (Audio Toolbox)の負値を損失関数として使用します。PermutationType は、発話レベル順列不変学習 (uPIT) または one-and-rest 順列不変学習 (OR-PIT) として設定できます。動的混合器オブジェクトの NumSignals パラメーターが範囲である場合、PermutationType"OR-PIT" に設定しなければなりません。

[net,traininfo] = trainnet(mbqTrain,net, ...
    @(Y,T)-permutationInvariantSISNR(Y,T,PermutationType="uPIT"), ...
    options);

モデルの評価

検証セットからサンプルを読み取り、ミキシングを適用します。

[x,xinfo] = read(adsValidation);
xt = mix(mixer,x,xinfo);
predictor = xt{1};
targets = xt{2};

ミックスド シグナルをモデルに入力します。

y = (predict(net,predictor))';

順列不変 SI-SNR を評価し、話者信号の最適な順列のインデックスを取得します。

[metric,idx] = permutationInvariantSISNR(y,targets);
display("Permutation Invariant SI-SINR: " + metric + " dB")
    "Permutation Invariant SI-SINR: 15.1941 dB"

予測を最適な順列に並べ替えます。

y = y(:,idx);

順列不変 SI-SNR 改善 (SI-SNRi) を計算するには、sisnr (Audio Toolbox)を使用します。

ametric = sisnr(predictor,targets);
bmetric = sisnr(y,targets);
metricImprovement = mean(bmetric - ametric);
display("Permutation Invariant SI-SNRi: " + metricImprovement + " dB")
    "Permutation Invariant SI-SNRi: 15.21 dB"

分離の結果を再生し、プロットします。

sound(predictor,fs),pause(size(predictor,1)/fs+1)
y = y./max(abs(y));
sound(y(:,1),fs),pause(size(y,1)/fs+1)
sound(y(:,2),fs)

t = (0:size(predictor,1)-1)/fs;

tiledlayout(3,1)

nexttile()
plot(t,predictor)
ylabel('Input Signal')

nexttile()
plot(t,targets(:,1),'b',t,y(:,1),'r:')
ylabel('Signal 1')
legend('target','prediction')

nexttile()
plot(t,targets(:,2),'b',t,y(:,2),'r:')
ylabel('Signal 2')
legend('target','prediction')

サポート関数

リサンプリング変換

function [audioOut,adsInfo] = resampleTransform(audioIn,adsInfo,desiredFs)
    fs = adsInfo.SampleRate;

    % Convert to mono
    x = mean(audioIn,2);
    
    % Resample to 16 kHz
    if desiredFs~=fs
        x = resample(x,desiredFs,fs);
    end
    adsInfo.SampleRate = desiredFs;
    
    % Normalize so that the max absolute value of a signal is 1 
    audioOut = x/max(abs(x));
end

Conv-TasNet モデルの作成

function net = convtasnet(options)
% This model and its parameters are described in: 
% [1] Luo, Yi, and Nima Mesgarani. "Conv-TasNet: Surpassing Ideal
% Time–Frequency Magnitude Masking for Speech Separation." IEEE/ACM
% Transactions on Audio, Speech, and Language Processing, vol. 27, no. 8,
% Aug. 2019, pp. 1256–66. DOI.org (Crossref),
% https://doi.org/10.1109/TASLP.2019.2915167.
arguments
    options.NumFiltersAutoEncoder = 512   % N  - Number of filters in autoencoder
    options.FilterLength = 16             % L  - Length of filters (in samples)
    options.NumChannelsBottleneck = 128   % B  - Number of channels in bottleneck
    options.NumChannelsConvBlocks = 512   % H  - Number of channels in convolutional blocks
    options.NumChannelsSkipPaths = 128    % Sc - Number of channels in skip connection paths' 1x1-conv blocks
    options.KernelSizeConvBlocks = 3      % P  - Kernel size in convolution blocks
    options.NumConvBlocksInRepeat = 8     % X  - Number of convolutional blocks in each repeat
    options.NumRepeats = 3                % R  - Number of repeats
    options.NumResponses = 2              % Number of responses
end

N = options.NumFiltersAutoEncoder;
L = options.FilterLength;
B = options.NumChannelsBottleneck;
Sc = options.NumChannelsSkipPaths;
H = options.NumChannelsConvBlocks;
P = options.KernelSizeConvBlocks;
X = options.NumConvBlocksInRepeat; % Also referred to as "M" in [1].
R = options.NumRepeats;
numResponses = options.NumResponses;


net = dlnetwork();

% Encoder + Bottleneck -----------------------------------------------------
net = addLayers(net,[ ...
    sequenceInputLayer(1,MinLength=L,Name="input")
    convolution1dLayer(L,N,Stride=L/2,Name="encoder.conv1d")
    layerNormalizationLayer(OperationDimension="batch-excluded",Name="bottleneck.layernorm")
    convolution1dLayer(1,B,Name="bottleneck.conv1d")]);


% TCN ----------------------------------------------------------------------
previousLayer = "bottleneck.conv1d";
dilationFactor = repmat((2.^(0:X-1)),1,R);
net = addLayers(net,additionLayer(X*R,Name="skip_connection"));

for ii = 1:(X*R)

    net = addLayers(net,[ ...
        convolution1dLayer(1,H,Name="TCN."+ii+".conv1d")
        preluLayer(Name="TCN."+ii+"prelu_1")
        layerNormalizationLayer(OperationDimension="batch-excluded",Name="TCN."+ii+".layernorm_1")

        groupedConvolution1dLayer(FilterSize=P,DilationFactor=dilationFactor(ii),Name="TCN."+ii+".groupconv1d")
        preluLayer(Name="TCN."+ii+".prelu_2")
        layerNormalizationLayer(OperationDimension="batch-excluded",Name="TCN."+ii+".layernorm_2")

        convolution1dLayer(1,Sc,Name="TCN."+ii+".skip.conv1d")
        ]);

    if ii~=(X*R)
        net = addLayers(net,[ ...
            convolution1dLayer(1,B,Name="TCN."+ii+".residual.conv1d") ...
            additionLayer(2,Name="TCN."+ii+".output+residual")]);

        net = connectLayers(net,"TCN."+ii+".layernorm_2","TCN."+ii+".residual.conv1d");
        net = connectLayers(net,previousLayer,"TCN."+ii+".output+residual"+"/in2");
    end

    net = connectLayers(net,previousLayer,"TCN."+ii+".conv1d");

    previousLayer = "TCN."+ii+".output+residual";

    net = connectLayers(net,"TCN."+(ii)+".skip.conv1d","skip_connection/in"+ii);
end


% Output -------------------------------------------------------------------
net = addLayers(net,[ ...
    preluLayer(Name="mask.prelu")
    convolution1dLayer(1,numResponses*N,Name="mask.conv1d")
    splitMaskLayer(NumSpeakers=numResponses,Name="splitMask")
    reluLayer(Name='relu') ... % The original paper uses sigmoid. Common practice now is ReLU, but results have not been compared.
    ]);
net = connectLayers(net,"skip_connection","mask.prelu");

% Apply Mask
net = addLayers(net,[ ...
    addDimLayer(Name="viewLayer_2",DimToAdd="S")
    multiplicationLayer(2,Name="apply_mask")
    reformatLayer(Name="viewLayer_3",Format="TCBS")
    transposedConv1dLayer(L,1,Stride=L/2,Name='return')
    reformatLayer(Name="viewLayer_4",Format="TCBS")
    removeDimLayer(Name="formatOutput",DimToRemove='C')]);

net = connectLayers(net,"encoder.conv1d","viewLayer_2");
net = connectLayers(net,"relu","apply_mask/in2");

net = initialize(net,dlarray(zeros(L,1,'single'),'TC'));

end

前処理ミニバッチ

function [X,T] = preprocessMiniBatch(Xin,mixer,fs)
% preprocessMiniBatch - Preprocess mini-batch
info.SampleRate = fs;

% Get the median signal length
signalLengths = cellfun(@(x)size(x,1),Xin,UniformOutput=false);
medianSignalLength = min(round(median(cat(1,signalLengths{:}))),mixer.MaxLength);

% Trim or pad signals to same length
Xin = cellfun(@(x)trimOrPad(x,medianSignalLength),Xin,UniformOutput=false);

% Apply mixing
dataOut = mixer.mix(Xin,info);

% Output the predictors and targets separately, with batches along the
% third dimension
X = cat(3,dataOut{:,1});
T = cat(3,dataOut{:,2});

end

LibriSpeech のダウンロード

function datasetFolder = downloadAndUnzip(dataset,downloadDatasetFolder)
filename = dataset + ".tar.gz";
datasetFolder = fullfile(downloadDatasetFolder,"LibriSpeech",dataset);
url = "http://us.openslr.org/resources/12/" + filename;
if ~datasetExists(datasetFolder)
    gunzip(url,downloadDatasetFolder);
    unzippedFile = fullfile(downloadDatasetFolder,filename);
    untar(unzippedFile{1}(1:end-3),downloadDatasetFolder);
end
end

参考文献

[1] Luo, Yi, and Nima Mesgarani. "Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation." IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 27, no. 8, Aug. 2019, pp. 1256–66. DOI.org (Crossref), https://doi.org/10.1109/TASLP.2019.2915167.

[2] Takahashi, Naoya, et al. "Recursive Speech Separation for Unknown Number of Speakers." Interspeech 2019, ISCA, 2019, pp. 1348–52. DOI.org (Crossref), https://doi.org/10.21437/Interspeech.2019-1550.

Copyright 2024 The MathWorks, Inc.