Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

MFCC ネットワークと LSTM ネットワークを使用したノイズ内のキーワード スポッティング

この例では、深層学習ネットワークを使用してノイズを含む音声からキーワードを特定する方法を説明します。特に、この例では、双方向長短期記憶 (BiLSTM) ネットワークとメル周波数ケプストラム係数 (MFCC) を使用します。

はじめに

キーワード スポッティング (KWS) は、音声アシスト技術に欠かせない要素です。音声アシストでは、ユーザーがデバイスへのコマンドや質問を発話する前に、事前に定義されたキーワードを言うことでシステムを起動します。

この例では、メル周波数ケプストラム係数 (MFCC) の特徴シーケンスを使用して KWS 深層ネットワークに学習させます。また、データ拡張を使用することでノイズを含む環境におけるネットワークの精度がどのように改善されるかについても説明します。

この例では、シーケンスおよび時系列のデータの学習に適した再帰型ニューラル ネットワーク (RNN) の一種、長短期記憶 (LSTM) ネットワークを使用します。LSTM ネットワークは、シーケンスのタイム ステップ間の長期的な依存関係を学習できます。LSTM 層 (lstmLayer) は、順方向の時間系列を確認し、一方、双方向の LSTM 層 (bilstmLayer) は、順方向および逆方向両方の時間系列を確認できます。この例では、双方向の LSTM 層を使用します。

この例では、Google Speech Commands データセットを使用して深層学習モデルに学習させます。この例を実行するには、まず、データセットをダウンロードしなければなりません。データ セットをダウンロードしない、またはネットワークに学習させない場合、この例を MATLAB® で開いて "事前学習済みのネットワークを使用したキーワードのスポッティング" の節を実行することで、事前学習済みのネットワークをダウンロードして使用できます。

事前学習済みのネットワークを使用したキーワードのスポッティング

学習プロセスの詳細について説明する前に、事前学習済みのキーワード スポッティング ネットワークをダウンロードし、このネットワークを使用してキーワードを識別します。

この例では、特定するキーワードは YES です。

キーワードが発話されたテスト信号を読み取ります。

[audioIn,fs] = audioread("keywordTestSignal.wav");
sound(audioIn,fs)

事前学習済みのネットワーク、特徴の正規化に使用する平均値 (M) および標準偏差 (S) のベクトル、およびこの例の後半でネットワークの検証に使用する 2 つのオーディオ ファイルをダウンロードして読み込みます。

downloadFolder = matlab.internal.examples.downloadSupportFile("audio","KeywordSpotting.zip");
dataFolder = tempdir;
unzip(downloadFolder,dataFolder)
netFolder = fullfile(dataFolder,"KeywordSpotting");
load(fullfile(netFolder,"KWSNet.mat"));

audioFeatureExtractor (Audio Toolbox)オブジェクトを作成して特徴抽出を実行します。

windowLength = 512;
overlapLength = 384;
afe = audioFeatureExtractor(SampleRate=fs, ...
    Window=hann(windowLength,"periodic"),OverlapLength=overlapLength, ...
    mfcc=true,mfccDelta=true,mfccDeltaDelta=true);

テスト信号から特徴を抽出し、それらを正規化します。

features = extract(afe,audioIn);

features = (features - M)./S;

キーワード スポッティング バイナリ マスクを計算します。マスク値 1 は、キーワードがスポッティングされたセグメントに対応します。

mask = classify(KWSNet,features.');

マスク内の各サンプルは、音声信号から取得した 128 個のサンプル (windowLength - overlapLength) に対応します。

マスクを信号の長さに拡張します。

mask = repmat(mask,windowLength-overlapLength,1);
mask = double(mask) - 1;
mask = mask(:);

テスト信号とマスクをプロットします。

figure
audioIn = audioIn(1:length(mask));
t = (0:length(audioIn)-1)/fs;
plot(t,audioIn)
grid on
hold on
plot(t, mask)
legend("Speech","YES")

スポッティングされたキーワードを再生します。

sound(audioIn(mask==1),fs)

マイクからのストリーミング オーディオを使用したコマンドの検出

事前学習済みのコマンド検出ネットワークをマイクからのストリーミング オーディオでテストします。キーワード (YES) を含む、ランダムな単語を発声してみます。

audioFeatureExtractor オブジェクトでgenerateMATLABFunction (Audio Toolbox)を呼び出し、特徴抽出関数を作成します。この関数は、処理ループで使用します。

generateMATLABFunction(afe,"generateKeywordFeatures",IsStreaming=true);

マイクからオーディオを読み取ることができるオーディオ デバイス リーダーを定義します。フレーム長をホップ長に設定します。これにより、マイクからのすべての新しいオーディオ フレームについて、特徴の新しいセットを計算できます。

hopLength = windowLength - overlapLength;
frameLength = hopLength;
adr = audioDeviceReader(SampleRate=fs,SamplesPerFrame=frameLength);

音声信号と推定されたマスクを可視化するためのスコープを作成します。

scope = timescope(SampleRate=fs, ...
    TimeSpanSource="property", ...
    TimeSpan=5, ...
    TimeSpanOverrunAction="Scroll", ...
    BufferLength=fs*5*2, ...
    ShowLegend=true, ...
    ChannelNames={'Speech','Keyword Mask'}, ...
    YLimits=[-1.2,1.2], ...
    Title="Keyword Spotting");

マスクを推定するレートを定義します。numHopsPerUpdate オーディオ フレームごとに 1 回マスクを生成します。

numHopsPerUpdate = 16;

オーディオのバッファーを初期化します。

dataBuff = dsp.AsyncBuffer(windowLength);

計算された特徴のバッファーを初期化します。

featureBuff = dsp.AsyncBuffer(numHopsPerUpdate);

バッファーを初期化し、オーディオおよびマスクのプロットを管理します。

plotBuff = dsp.AsyncBuffer(numHopsPerUpdate*windowLength);

ループ処理を無限に実行するには、timeLimit を Inf に設定します。シミュレーションを停止するには、スコープを閉じます。

timeLimit = 20;

tic
while toc < timeLimit

    data = adr();
    write(dataBuff,data);
    write(plotBuff,data);

    frame = read(dataBuff,windowLength,overlapLength);
    features = generateKeywordFeatures(frame,fs);
    write(featureBuff,features.');

    if featureBuff.NumUnreadSamples == numHopsPerUpdate
        featureMatrix = read(featureBuff);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        featureMatrix = (featureMatrix - M)./S;

        [keywordNet,v] = classifyAndUpdateState(KWSNet,featureMatrix.');
        v = double(v) - 1;
        v = repmat(v,hopLength,1);
        v = v(:);
        v = mode(v);
        v = repmat(v,numHopsPerUpdate*hopLength,1);

        data = read(plotBuff);
        scope([data,v]);

        if ~isVisible(scope)
            break;
        end
    end
end
hide(scope)

この例の残りの部分では、キーワード スポッティング ネットワークに学習させる方法を学習します。

学習プロセスの概要

学習プロセスでは、次の手順を実行します。

  1. 検証信号に対する "黄金律" となるキーワード スポッティングのベースラインを調べます。

  2. ノイズを含まないデータセットから、学習用の発話を作成します。

  3. この発話から抽出した MFCC シーケンスを使用して、キーワード スポッティングの LSTM ネットワークに学習させます。

  4. 検証信号に適用した場合のネットワークの出力を検証のベースラインと比較して、ネットワークの精度をチェックします。

  5. ノイズで破損した検証信号に対するネットワークの精度をチェックします。

  6. audioDataAugmenter (Audio Toolbox) を使用して音声データにノイズを注入し、学習データセットを拡張します。

  7. 拡張したデータセットを使用して、ネットワークに再学習させます。

  8. 再学習させたネットワークを検証し、ノイズを含む検証信号に適用したときの精度が向上していることを確認します。

検証データの検査

サンプルの音声信号を使用して KWS ネットワークを検証します。この検証信号には、YES というキーワードが断続的に出現する 34 秒の音声が含まれています。

検証信号を読み込みます。

[audioIn,fs] = audioread(fullfile(netFolder,"KeywordSpeech-16-16-mono-34secs.flac"));

信号を再生します。

sound(audioIn,fs)

信号を可視化します。

figure
t = (1/fs)*(0:length(audioIn)-1);
plot(t,audioIn);
grid on
xlabel("Time (s)")
title("Validation Speech Signal")

KWS のベースラインの検査

KWS のベースラインを読み込みます。このベースラインは、speech2text信号ラベラー (Signal Processing Toolbox)を使用して取得されたものです。関連する例については、オーディオ信号内の発声された単語のラベル付け (Signal Processing Toolbox)を参照してください。

load("KWSBaseline.mat","KWSBaseline")

ベースラインは、検証用オーディオ信号と同じ長さの logical ベクトルです。キーワードが発話された audioIn のセグメントは、KWSBaseline で 1 に設定されています。

KWS のベースラインに沿って音声信号を可視化します。

fig = figure;
plot(t,[audioIn,KWSBaseline'])
grid on
xlabel("Time (s)")
legend("Speech","KWS Baseline",Location="southeast")
l = findall(fig,"type","line");
l(1).LineWidth = 2;
title("Validation Signal")

キーワードとして識別された音声のセグメントを再生します。

sound(audioIn(KWSBaseline),fs)

学習させるネットワークの目的は、このベースラインのように 0 と 1 から成る KWS マスクを出力することです。

音声コマンド データセットの読み込み

Google Speech Commands データセット ([1]) をダウンロードし、解凍します。

downloadFolder = matlab.internal.examples.downloadSupportFile("audio","google_speech.zip");
dataFolder = tempdir;
unzip(downloadFolder,dataFolder)
dataset = fullfile(dataFolder,"google_speech");

データセットを指す audioDatastore (Audio Toolbox) を作成します。

ads = audioDatastore(dataset,LabelSource="foldername",Includesubfolders=true);
ads = shuffle(ads);

データセットには、この例では使用されないバックグラウンド ノイズのファイルが含まれています。subset (Audio Toolbox) を使用して、バックグラウンド ノイズのファイルを含まない新しいデータストアを作成します。

isBackNoise = ismember(ads.Labels,"background");
ads = subset(ads,~isBackNoise);

このデータセットには、30 個の短い単語 (キーワードの YES を含む) で構成される 1 秒間の発話が約 65,000 個含まれています。データストアに含まれる単語の分布の内訳を取得します。

countEachLabel(ads)
ans=30×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

ads を 2 つのデータストアに分割します。1 番目のデータストアには、キーワードに対応するファイルが含まれます。2 番目のデータストアには、その他の単語がすべて含まれます。

keyword = "yes";
isKeyword = ismember(ads.Labels,keyword);
adsKeyword = subset(ads,isKeyword);
adsOther = subset(ads,~isKeyword);

データセット全体を使ってネットワークに学習させ、できる限り精度を高くするには、speedupExamplefalse に設定します。この例を短時間で実行するには、speedupExampletrue に設定します。

speedupExample = false;
if speedupExample
    % Reduce the dataset by a factor of 20
    adsKeyword = splitEachLabel(adsKeyword,round(numel(adsKeyword.Files)/20));
    numUniqueLabels = numel(unique(adsOther.Labels));
    adsOther = splitEachLabel(adsOther,round(numel(adsOther.Files)/numUniqueLabels/20));
end

各データストアに含まれる単語の分布の内訳を取得します。連続した読み取り操作によって異なる単語が返されるように、adsOther データストアをシャッフルします。

countEachLabel(adsKeyword)
ans=1×2 table
    Label    Count
    _____    _____

     yes     2377 

countEachLabel(adsOther)
ans=29×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

adsOther = shuffle(adsOther);

学習用の文とラベルの作成

学習データストアは、1 つの単語が発話された 1 秒間の音声信号で構成されています。ここでは、キーワードと他の単語が混在する複雑な学習用発話音声を作成します。

以下は、作成された発話の例です。キーワードのデータストアからキーワードを 1 つ読み取り、最大値が 1 となるようにそのキーワードを正規化します。

yes = read(adsKeyword);
yes = yes/max(abs(yes));

この信号には、有意な音声情報をもたない音声以外の部分 (サイレンス、バックグラウンド ノイズなど) が含まれています。この例では、detectSpeech (Audio Toolbox) を使用してサイレンスを除去します。

信号に含まれる有意な部分の開始インデックスと終了インデックスを取得します。

speechIndices = detectSpeech(yes,fs);

合成された学習文で使用する単語の数をランダムに選択します。最大 10 個の単語を使用します。

numWords = randi([0,10]);

キーワードが出現する場所をランダムに選択します。

keywordLocation = randi([1,numWords+1]);

キーワード以外の発話を必要な数だけ読み取り、学習文とマスクを作成します。

sentence = [];
mask = [];
for index = 1:numWords+1
    if index == keywordLocation
        sentence = [sentence;yes]; %#ok
        newMask = zeros(size(yes));
        newMask(speechIndices(1,1):speechIndices(1,2)) = 1;
        mask = [mask;newMask]; %#ok
    else
        other = read(adsOther);
        other = other./max(abs(other));
        sentence = [sentence;other]; %#ok
        mask = [mask;zeros(size(other))]; %#ok
    end
end

学習文をマスクとともにプロットします。

figure
t = (1/fs)*(0:length(sentence)-1);
fig = figure;
plot(t,[sentence,mask])
grid on
xlabel("Time (s)")
legend("Training Signal","Mask",Location="southeast")
l = findall(fig,"type","line");
l(1).LineWidth = 2;
title("Example Utterance")

学習文を再生します。

sound(sentence,fs)

特徴の抽出

この例では、39 個の MFCC 係数 (13 個の MFCC 係数、13 個のデルタ係数、および 13 個のデルタデルタ係数) を使用して深層学習ネットワークに学習させます。

MFCC の抽出に必要なパラメーターを定義します。

windowLength = 512;
overlapLength = 384;

audioFeatureExtractor オブジェクトを作成して特徴抽出を実行します。

afe = audioFeatureExtractor(SampleRate=fs, ...
    Window=hann(windowLength,"periodic"),OverlapLength=overlapLength, ...
    mfcc=true,mfccDelta=true,mfccDeltaDelta=true);

特徴を抽出します。

featureMatrix = extract(afe,sentence);
size(featureMatrix)
ans = 1×2

   478    39

入力に対してウィンドウをスライドさせて MFCC を計算するため、特徴行列は入力の音声信号より短くなることに注意してください。featureMatrix の各行は、音声信号から取得した 128 個のサンプル (windowLength - overlapLength) に対応します。

featureMatrix と同じ長さのマスクを計算します。

hopLength = windowLength - overlapLength;
range = hopLength*(1:size(featureMatrix,1)) + hopLength;
featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(mask((index-1)*hopLength+1:(index-1)*hopLength+windowLength));
end

学習データセットからの特徴の抽出

学習データセット全体に対して文の合成と特徴の抽出を行うと、非常に時間がかかります。Parallel Computing Toolbox™ がある場合、処理を高速化するために、学習データストアを分割して各区画を別々のワーカーで処理します。

データストアの区画数を選択します。

numPartitions = 6;

特徴行列とマスクの cell 配列を初期化します。

TrainingFeatures = {};
TrainingMasks= {};

parfor を使用して、文の合成、特徴の抽出、およびマスクの作成を行います。

emptyCategories = categorical([1 0]);
emptyCategories(:) = [];

tic
parfor ii = 1:numPartitions

    subadsKeyword = partition(adsKeyword,numPartitions,ii);
    subadsOther = partition(adsOther,numPartitions,ii);

    count = 1;
    localFeatures = cell(length(subadsKeyword.Files),1);
    localMasks = cell(length(subadsKeyword.Files),1);

    while hasdata(subadsKeyword)

        % Create a training sentence
        [sentence,mask] = synthesizeSentence(subadsKeyword,subadsOther,fs,windowLength);

        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;

        % Create mask
        range = hopLength*(1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask((index-1)*hopLength+1:(index-1)*hopLength+windowLength));
        end

        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];

        count = count + 1;
    end

    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
Analyzing and transferring files to the workers ...done.
disp("Training feature extraction took " + toc + " seconds.")
Training feature extraction took 41.0509 seconds.

平均が 0、標準偏差が 1 になるよう、すべての特徴を正規化することをお勧めします。各係数の平均と標準偏差を計算し、それらを使用してデータを正規化します。

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if speedupExample
    load(fullfile(netFolder,"keywordNetNoAugmentation.mat"),"keywordNetNoAugmentation","M","S");
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M)./S;
    TrainingFeatures{index} = f.'; %#ok
end

検証特徴の抽出

検証信号から MFCC 特徴を抽出します。

featureMatrix = extract(afe, audioIn);
featureMatrix(~isfinite(featureMatrix)) = 0;

検証特徴を正規化します。

FeaturesValidationClean = (featureMatrix - M)./S;
range = hopLength*(1:size(FeaturesValidationClean,1)) + hopLength;

検証用の KWS マスクを作成します。

featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(KWSBaseline((index-1)*hopLength+1:(index-1)*hopLength+windowLength));
end
BaselineV = categorical(featureMask);

LSTM ネットワーク アーキテクチャの定義

LSTM ネットワークは、シーケンス データのタイム ステップ間の長期的な依存関係を学習できます。この例では、双方向の LSTM 層 bilstmLayer を使用し、シーケンスを順方向および逆方向の両方で確認します。

サイズ numFeatures のシーケンスになるように入力サイズを指定します。出力サイズが 150 である 2 つの双方向の隠れ LSTM 層を指定し、シーケンスを出力します。このコマンドは、双方向の LSTM 層に対し、入力時系列を 150 個の特徴にマッピングするよう指示します。これらの特徴は次の層に渡されます。サイズが 2 の全結合層を含めることによって 2 個のクラスを指定し、その後にソフトマックス層と分類層を配置します。

layers = [ ...
    sequenceInputLayer(numFeatures)
    bilstmLayer(150,OutputMode="sequence")
    bilstmLayer(150,OutputMode="sequence")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ];

学習オプションの定義

分類器の学習オプションを指定します。ネットワークが学習データから 10 個のパスを作るよう MaxEpochs を 10 に設定します。ネットワークが一度に 64 個の学習信号を確認するよう、MiniBatchSize64 に設定します。Plots"training-progress" に設定し、反復回数の増大に応じた学習の進行状況を示すプロットを生成します。Verbosefalse に設定し、プロットで示されるデータに対応する表出力の表示を無効にします。Shuffle"every-epoch" に設定し、各エポックの最初に学習シーケンスをシャッフルします。LearnRateSchedule"piecewise" に設定し、特定のエポック数 (5) が経過するたびに、指定された係数 (0.1) によって学習率を減らします。ValidationData を検証予測子とターゲットに設定します。

この例では、適応モーメント推定 (ADAM) ソルバーを使用します。ADAM は、LSTM などの再帰型ニューラル ネットワーク (RNN) では、既定のモーメンタム項付き確率的勾配降下法 (SGDM) ソルバーよりもパフォーマンスにおいて優れています。

maxEpochs = 10;
miniBatchSize = 64;
options = trainingOptions("adam", ...
    InitialLearnRate=1e-4, ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Verbose=false, ...
    ValidationFrequency=floor(numel(TrainingFeatures)/miniBatchSize), ...
    ValidationData={FeaturesValidationClean.',BaselineV}, ...
    Plots="training-progress", ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.1, ...
    LearnRateDropPeriod=5);

LSTM ネットワークの学習

trainNetwork を使用し、指定した学習オプションと層のアーキテクチャで LSTM ネットワークに学習させます。学習セットが大きいため、学習プロセスには数分かかる場合があります。

[keywordNetNoAugmentation,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if speedupExample
    load(fullfile(netFolder,"keywordNetNoAugmentation.mat"),"keywordNetNoAugmentation","M","S");
end

ノイズを含まない検証信号に対するネットワーク精度のチェック

学習済みネットワークを使用して、検証信号の KWS マスクを推定します。

v = classify(keywordNetNoAugmentation,FeaturesValidationClean.');

実際のラベルと推定されたラベルのベクトルから、検証用の混同行列を計算してプロットします。

figure
confusionchart(BaselineV,v, ...
    Title="Validation Accuracy", ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");

ネットワークの出力を categorical から double に変換します。

v = double(v) - 1;
v = repmat(v,hopLength,1);
v = v(:);

ネットワークによって特定されたキーワードの領域を再生します。

sound(audioIn(logical(v)),fs)

推定された KWS マスクと期待される KWS マスクを可視化します。

baseline = double(BaselineV) - 1;
baseline = repmat(baseline,hopLength,1);
baseline = baseline(:);

t = (1/fs)*(0:length(v)-1);
fig = figure;
plot(t,[audioIn(1:length(v)),v,0.8*baseline])
grid on
xlabel("Time (s)")
legend("Training Signal","Network Mask","Baseline Mask",Location="southeast")
l = findall(fig,"type","line");
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title("Results for Noise-Free Speech")

ノイズを含む検証信号に対するネットワーク精度のチェック

次に、ノイズを含む音声信号に対するネットワークの精度をチェックします。ノイズを含む信号は、クリーンな検証信号を加法性ホワイト ガウス ノイズで破損させて得たものです。

ノイズを含む信号を読み込みます。

[audioInNoisy,fs] = audioread(fullfile(netFolder,"NoisyKeywordSpeech-16-16-mono-34secs.flac"));
sound(audioInNoisy,fs)

信号を可視化します。

figure
t = (1/fs)*(0:length(audioInNoisy)-1);
plot(t,audioInNoisy)
grid on
xlabel("Time (s)")
title("Noisy Validation Speech Signal")

ノイズを含む信号から特徴行列を抽出します。

featureMatrixV = extract(afe, audioInNoisy);
featureMatrixV(~isfinite(featureMatrixV)) = 0;
FeaturesValidationNoisy = (featureMatrixV - M)./S;

特徴行列をネットワークに渡します。

v = classify(keywordNetNoAugmentation,FeaturesValidationNoisy.');

ネットワークの出力をベースラインと比較します。精度は、クリーンな信号で得られたものより低いことに注意してください。

figure
confusionchart(BaselineV,v, ...
    Title="Validation Accuracy - Noisy Speech", ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");

ネットワークの出力を categorical から double に変換します。

v = double(v) - 1;
v = repmat(v,hopLength,1);
v = v(:);

ネットワークによって特定されたキーワードの領域を再生します。

sound(audioIn(logical(v)),fs)

推定されたマスクとベースラインのマスクを可視化します。

t = (1/fs)*(0:length(v)-1);
fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel("Time (s)")
legend("Training Signal","Network Mask","Baseline Mask",Location="southeast")
l = findall(fig,"type","line");
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title("Results for Noisy Speech - No Data Augmentation")

データ拡張の実行

学習で使用したデータセットにはノイズのない文だけが含まれているため、学習済みのネットワークはノイズを含む信号に対して良好に作動しませんでした。この問題を修正するため、データセットを拡張し、ノイズを含む文を追加します。

audioDataAugmenter (Audio Toolbox) を使用してデータセットを拡張します。

ada = audioDataAugmenter(TimeStretchProbability=0,PitchShiftProbability=0, ...
    VolumeControlProbability=0,TimeShiftProbability=0, ...
    SNRRange=[-1,1],AddNoiseProbability=0.85);

audioDataAugmenter オブジェクトは、この設定に基づいて、ガウス ホワイト ノイズを確率 85% で使用して入力オーディオ信号を破損させます。SNR は、[-1 1] (dB 単位) の範囲からランダムに選択されます。オーグメンターによって入力信号が変更されないが加えられない確率は 15% です。

例として、オーディオ信号をオーグメンターに渡します。

reset(adsKeyword)
x = read(adsKeyword);
data = augment(ada,x,fs)
data=1×2 table
         Audio          AugmentationInfo
    ________________    ________________

    {16000×1 double}       1×1 struct   

data の変数 AugmentationInfo を検査して、信号がどのように変更されたかを確認します。

data.AugmentationInfo
ans = struct with fields:
    SNR: 0.3410

データストアをリセットします。

reset(adsKeyword)
reset(adsOther)

特徴とマスクのセルを初期化します。

TrainingFeatures = {};
TrainingMasks = {};

特徴抽出を再度実行します。各信号は 85% の確率でノイズによって破損します。そのため、拡張したデータセットの約 85% のデータにノイズが含まれ、約 15% のデータにはノイズが含まれません。

tic
parfor ii = 1:numPartitions

    subadsKeyword = partition(adsKeyword,numPartitions,ii);
    subadsOther = partition(adsOther,numPartitions,ii);

    count = 1;
    localFeatures = cell(length(subadsKeyword.Files),1);
    localMasks = cell(length(subadsKeyword.Files),1);

    while hasdata(subadsKeyword)

        [sentence,mask] = synthesizeSentence(subadsKeyword,subadsOther,fs,windowLength);

        % Corrupt with noise
        augmentedData = augment(ada,sentence,fs);
        sentence = augmentedData.Audio{1};

        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;

        range = hopLength*(1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask((index-1)*hopLength+1:(index-1)*hopLength+windowLength));
        end

        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];

        count = count + 1;
    end

    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
disp("Training feature extraction took " + toc + " seconds.")
Training feature extraction took 35.6612 seconds.

各係数の平均と標準偏差を計算し、それらを使用してデータを正規化します。

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if speedupExample
    load(fullfile(netFolder,"KWSNet.mat"),"KWSNet","M","S");
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M) ./ S;
    TrainingFeatures{index} = f.'; %#ok
end

新しい平均と標準偏差の値を使用して、検証特徴を正規化します。

FeaturesValidationNoisy = (featureMatrixV - M)./S;

拡張したデータセットによるネットワークの再学習

学習オプションを再作成します。ノイズを含むベースラインの特徴とマスクを検証に使用します。

options = trainingOptions("adam", ...
    InitialLearnRate=1e-4, ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Verbose=false, ...
    ValidationFrequency=floor(numel(TrainingFeatures)/miniBatchSize), ...
    ValidationData={FeaturesValidationNoisy.',BaselineV}, ...
    Plots="training-progress", ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.1, ...
    LearnRateDropPeriod=5);

ネットワークに学習をさせます。

[KWSNet,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if speedupExample
    load(fullfile(netFolder,"KWSNet.mat"));
end

検証信号に対するネットワークの精度を確認します。

v = classify(KWSNet,FeaturesValidationNoisy.');

推定された KWS マスクと期待される KWS マスクを比較します。

figure
confusionchart(BaselineV,v, ...
    Title="Validation Accuracy with Data Augmentation", ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");

特定されたキーワードの領域を再生します。

v = double(v) - 1;
v = repmat(v,hopLength,1);
v = v(:);

sound(audioIn(logical(v)),fs)

推定されたマスクと期待されるマスクを可視化します。

fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel("Time (s)")
legend("Training Signal","Network Mask","Baseline Mask",Location="southeast")
l = findall(fig,"type","line");
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title("Results for Noisy Speech - With Data Augmentation")

サポート関数

文の合成

function [sentence,mask] = synthesizeSentence(adsKeyword,adsOther,fs,minlength)

% Read one keyword
keyword = read(adsKeyword);
keyword = keyword./max(abs(keyword));

% Identify region of interest
speechIndices = detectSpeech(keyword,fs);
if isempty(speechIndices) || diff(speechIndices(1,:)) <= minlength
    speechIndices = [1,length(keyword)];
end
keyword = keyword(speechIndices(1,1):speechIndices(1,2));

% Pick a random number of other words (between 0 and 10)
numWords = randi([0,10]);
% Pick where to insert keyword
loc = randi([1,numWords+1]);
sentence = [];
mask = [];
for index = 1:numWords+1
    if index==loc
        sentence = [sentence;keyword];
        newMask = ones(size(keyword));
        mask = [mask;newMask];
    else
        other = read(adsOther);
        other = other./max(abs(other));
        sentence = [sentence;other];
        mask = [mask;zeros(size(other))];
    end
end
end

参考文献

[1] Warden P. "Speech Commands: A public dataset for single-word speech recognition", 2017. Available from https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Copyright Google 2017. Speech Commands Dataset は、Creative Commons Attribution 4.0 license に従ってライセンスされています。

参考

| | |

関連するトピック