Main Content

深層学習を使用したプリエンファシス フィルターの学習

この例では、畳み込み深層ネットワークを使用して音声認識用のプリエンファシス フィルターの学習を行う方法を示します。この例では、学習可能な短時間フーリエ変換 (STFT) 層を使用して、2 次元畳み込み層での使用に適した時間-周波数表現を取得します。学習可能な STFT を使用すると、プリエンファシス フィルターの重みを勾配ベースで最適化できます。

データ

Free Spoken Digit データセット (FSDD) をクローンするかまたはダウンロードします。このデータセットは、https://github.com/Jakobovski/free-spoken-digit-dataset から入手できます。FSDD はオープンなデータセットなので、時間の経過とともに大きくなる可能性があります。この例では、2020 年 8 月 20 日にまとめられたバージョンを使用します。このバージョンには、6 人の話者が発話した 0 ~ 9 の英語の数字の録音が 3000 件含まれています。データは 8000 Hz でサンプリングされます。

この例では、MATLAB の tempdir の値に対応するフォルダーにデータをダウンロードしているものと仮定します。別のフォルダーを使用する場合は、次のコードの tempdir をそのフォルダー名に置き換えます。audioDatastore を使用してデータ アクセスを管理し、データが必ず学習セットとテスト セットにランダムに分割されるようにします。

pathToRecordingsFolder = fullfile(tempdir,'free-spoken-digit-dataset','recordings');
ads = audioDatastore(pathToRecordingsFolder);

関数 filenames2labels を使用して、FSDD ファイルからラベルの categorical ベクトルを取得します。データ セット内の各ラベルの数を表示します。

lbls = filenames2labels(ads,ExtractBefore="_");
ads.Labels = lbls;
countlabels(lbls)
ans=10×3 table
    Label    Count    Percent
    _____    _____    _______

      0       300       10   
      1       300       10   
      2       300       10   
      3       300       10   
      4       300       10   
      5       300       10   
      6       300       10   
      7       300       10   
      8       300       10   
      9       300       10   

各サブセットで均等なクラス比率を維持するように、FSDD を学習セットとテスト セットに分割します。再現可能な結果を得るために、乱数発生器を既定値に設定します。80%、つまり 2400 件の録音が学習に使用されます。残りの 600 件の録音、つまり全体の 20% がテスト用に取り分けられます。学習セットとテスト セットを作成する前に、データストア内のファイルを 1 回シャッフルします。

rng default;
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8,0.2);

FSDD に含まれる録音の長さは同じではありません。変換を使用し、データストアからの各読み取りが 8192 サンプルになるようにパディングまたは打ち切られるようにします。データをさらに単精度にキャストし、z スコア正規化を適用します。

transTrain = transform(adsTrain,@(x,info)helperReadData(x,info),'IncludeInfo',true);
transTest = transform(adsTest,@(x,info)helperReadData(x,info),'IncludeInfo',true);

深層畳み込みニューラル ネットワーク (DCNN) アーキテクチャ

この例では、次の深層畳み込みネットワークでカスタム学習ループを使用します。

numF = 12;
dropoutProb = 0.2;
layers = [
    sequenceInputLayer(1,'Name','input','MinLength',8192,...
         'Normalization',"none")

    convolution1dLayer(5,1,"name","pre-emphasis-filter",...
    "WeightsInitializer",@(sz)kronDelta(sz),"BiasLearnRateFactor",0)  

    stftLayer('Window',hamming(1280),'OverlapLength',900,...
    'Name','STFT') 
    
    convolution2dLayer(5,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')
    
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')
   
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    dropoutLayer(dropoutProb)
    globalAveragePooling2dLayer
    fullyConnectedLayer(numel(categories(ads.Labels)))
    softmaxLayer    
    ];
dlnet = dlnetwork(layers);

シーケンス入力層の後には、5 つの係数をもつ単一のフィルターで構成される 1 次元畳み込み層が続きます。これは有限インパルス応答フィルターです。深層学習ネットワークの畳み込み層は、既定で入力特徴に対してアフィン演算を実装します。厳密な線形 (フィルター処理) 演算を取得するには、'BiasInitializer' の既定値 ('zeros') を使用し、層のバイアス学習率係数を 0 に設定します。これは、バイアスが 0 に初期化され、学習中に変更されないことを意味します。ネットワークは、スケーリングされたクロネッカー デルタ シーケンスとなるようなフィルター重みのカスタム初期化を使用します。これはオールパス フィルターであり、入力のフィルター処理を実行しません。オールパス フィルターの重み初期化子のコードをここに示します。

function delta = kronDelta(sz)
% This function is only for use in the "Learn Pre-Emphasis Filter using
% Deep Learning" example. It may change or be removed in a
% future release.

L = sz(1);
delta = zeros(L,sz(2),sz(3),'single');
delta(1) = 1/sqrt(L);

end

stftLayer は、フィルター処理された入力信号のバッチを取得し、その振幅 STFT を取得します。振幅 STFT は信号の 2 次元表現であり、2 次元畳み込みネットワークでの使用に適しています。

ここで、STFT の重みは学習中に変更されませんが、この層は逆伝播をサポートしており、これにより "プリエンファシスフィルター" 層のフィルター係数を学習できるようになります。

ネットワークの学習

カスタム学習ループの学習オプションを設定します。ミニバッチ サイズ 128 で 70 エポックを使用します。初期学習率を 0.001 に設定します。

NumEpochs = 70;
miniBatchSize = 128;
learnRate = 0.001;

カスタム学習ループでは、minibatchqueue オブジェクトを使用します。関数 processSpeechMB は、ミニバッチを読み取り、one-hot 符号化スキームをラベルに適用します。

mbqTrain = minibatchqueue(transTrain,2,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFormat', {'CBT','CB'}, ... 
    'MiniBatchFcn', @processSpeechMB);

ネットワークに学習させ、反復ごとに損失をプロットします。Adam オプティマイザーを使用して、ネットワークの学習可能なパラメーターを更新します。学習の進行状況に応じて損失をプロットするには、次のコードの progress の値を "training-progress" に設定します。

progress = "final-loss";
if progress == "training-progress"
    figure
    lineLossTrain = animatedline;
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

% Initialize some training loop variables
trailingAvg = [];
trailingAvgSq = [];
iteration = 0;
lossByIteration = 0;

% Loop over epochs and time the epochs
start = tic;

for epoch = 1:NumEpochs
    reset(mbqTrain)
    shuffle(mbqTrain)

    % Loop over mini-batches
    while hasdata(mbqTrain)
        iteration = iteration + 1;
        
        % Get the next minibatch and one-hot coded targets
        [dlX,Y] = next(mbqTrain);
        
        % Evaluate the model gradients and loss 
        [gradients, loss, state] = dlfeval(@modelGradSTFT,dlnet,dlX,Y);
        if progress == "final-loss"
            lossByIteration(iteration) = loss;
        end

        % Update the network state
        dlnet.State = state;
        
        % Update the network parameters using an Adam optimizer
        [dlnet,trailingAvg,trailingAvgSq] = adamupdate(...
            dlnet,gradients,trailingAvg,trailingAvgSq,iteration,learnRate);        
        
        % Display the training progress
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        if progress == "training-progress"
            addpoints(lineLossTrain,iteration,loss)
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
        end
        
    end
    disp("Training loss after epoch " + epoch + ": " + loss); 

end
Training loss after epoch 1: 1.5686
Training loss after epoch 2: 1.2063
Training loss after epoch 3: 0.70384
Training loss after epoch 4: 0.50291
Training loss after epoch 5: 0.35332
Training loss after epoch 6: 0.22536
Training loss after epoch 7: 0.14302
Training loss after epoch 8: 0.14749
Training loss after epoch 9: 0.1436
Training loss after epoch 10: 0.092127
Training loss after epoch 11: 0.053437
Training loss after epoch 12: 0.059123
Training loss after epoch 13: 0.07433
Training loss after epoch 14: 0.066282
Training loss after epoch 15: 0.11964
Training loss after epoch 16: 0.087663
Training loss after epoch 17: 0.069451
Training loss after epoch 18: 0.11175
Training loss after epoch 19: 0.044604
Training loss after epoch 20: 0.064503
Training loss after epoch 21: 0.050275
Training loss after epoch 22: 0.022125
Training loss after epoch 23: 0.092534
Training loss after epoch 24: 0.1393
Training loss after epoch 25: 0.015846
Training loss after epoch 26: 0.022516
Training loss after epoch 27: 0.01798
Training loss after epoch 28: 0.012391
Training loss after epoch 29: 0.0068496
Training loss after epoch 30: 0.036968
Training loss after epoch 31: 0.014514
Training loss after epoch 32: 0.0055389
Training loss after epoch 33: 0.0080868
Training loss after epoch 34: 0.0097247
Training loss after epoch 35: 0.0067841
Training loss after epoch 36: 0.0073048
Training loss after epoch 37: 0.0068763
Training loss after epoch 38: 0.064052
Training loss after epoch 39: 0.029343
Training loss after epoch 40: 0.055245
Training loss after epoch 41: 0.20821
Training loss after epoch 42: 0.052951
Training loss after epoch 43: 0.034677
Training loss after epoch 44: 0.020905
Training loss after epoch 45: 0.077562
Training loss after epoch 46: 0.0055673
Training loss after epoch 47: 0.015712
Training loss after epoch 48: 0.011886
Training loss after epoch 49: 0.0063345
Training loss after epoch 50: 0.0030241
Training loss after epoch 51: 0.0033596
Training loss after epoch 52: 0.0042235
Training loss after epoch 53: 0.0054001
Training loss after epoch 54: 0.0037229
Training loss after epoch 55: 0.0042717
Training loss after epoch 56: 0.0030938
Training loss after epoch 57: 0.0024514
Training loss after epoch 58: 0.005746
Training loss after epoch 59: 0.0027509
Training loss after epoch 60: 0.0069394
Training loss after epoch 61: 0.0024441
Training loss after epoch 62: 0.0054856
Training loss after epoch 63: 0.0012796
Training loss after epoch 64: 0.0013482
Training loss after epoch 65: 0.0038288
Training loss after epoch 66: 0.0013217
Training loss after epoch 67: 0.0022817
Training loss after epoch 68: 0.0025086
Training loss after epoch 69: 0.0013634
Training loss after epoch 70: 0.0014228
if progress == "final-loss"
        plot(1:iteration,lossByIteration)
        grid on 
        title('Training Loss by Iteration')
        xlabel("Iteration")
        ylabel("Loss")
end

Figure contains an axes object. The axes object with title Training Loss by Iteration contains an object of type line.

学習済みのネットワークを、ホールドアウトされたテスト セットでテストします。ミニバッチ サイズ 32 で minibatchqueue オブジェクトを使用します。

miniBatchSize = 32;
mbqTest = minibatchqueue(transTest,2,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFormat', {'CBT','CB'}, ... 
    'MiniBatchFcn', @processSpeechMB);

テスト セットをループ処理し、各ミニバッチのクラス ラベルを予測します。

numObservations = numel(adsTest.Files);
classes = string(unique(adsTest.Labels));

predictions = [];

% Loop over mini-batches
while hasdata(mbqTest)    
    % Read mini-batch of data
    dlX = next(mbqTest);

    % Make predictions on the minibatch
    dlYPred = predict(dlnet,dlX);

    % Determine corresponding classes
    predBatch = onehotdecode(dlYPred,classes,1);
    predictions = [predictions predBatch];  
end

ホールドアウトされたテスト セット内の 600 個の例で分類精度を評価します。

accuracy = mean(predictions' == categorical(adsTest.Labels))
accuracy = 0.9883

テストの性能は約 99% です。1 次元畳み込み層をコメントアウトして、プリエンファシス フィルターを使用せずにネットワークに再学習させることができます。プリエンファシス フィルターを使用しない場合のテスト パフォーマンスも約 96% と優れていますが、プリエンファシス フィルターを使用した方がわずかに改善されます。学習済みのプリエンファシス フィルターの使用によってテスト精度はわずかに向上しただけですが、注目すべき点は、たった 5 つの学習可能パラメーターをネットワークに追加することによってこれが達成されたことです。

学習済みのプリエンファシス フィルターを調べるには、1 次元畳み込み層の重みを抽出します。周波数応答をプロットします。データのサンプリング周波数が 8 kHz であることを思い出してください。フィルターはスケーリングされたクロネッカー デルタ シーケンス (オールパス フィルター) に初期化されているため、初期化済みフィルターの周波数応答と学習済みフィルターの応答を簡単に比較できます。

FIRFilter = dlnet.Layers(2).Weights;
[H,W] = freqz(FIRFilter,1,[],8000);
delta = kronDelta([5 1 1]);
Hinit = freqz(delta,1,[],4000);
plot(W,20*log10(abs([H Hinit])),'linewidth',2)
grid on
xlabel('Hz')
ylabel('dB')
legend('Learned Filter','Initial Filter','Location','SouthEast')
title('Learned Pre-emphasis Filter')

Figure contains an axes object. The axes object with title Learned Pre-emphasis Filter contains 2 objects of type line. These objects represent Learned Filter, Initial Filter.

この例では、信号の短時間フーリエ変換に基づいて、2 次元畳み込みネットワークの前処理ステップとしてプリエンファシス フィルターの学習を行う方法を示しました。逆伝播をサポートする stftLayer の機能により、深層ネットワーク内のフィルター重みを勾配ベースで最適化できるようになりました。この結果は、テスト セットに対するネットワーク性能がわずかに改善しただけでしたが、この改善は、学習可能なパラメーター数の微細な増加によって達成されました。

付録: 補助関数

function [out,info] = helperReadData(x,info)
% This function is only for use in the "Learn Pre-Emphasis Filter using
% Deep Learning" example. It may change or be removed in a
% future release.

N = numel(x);
x = single(x);
if N > 8192
    x = x(1:8192);
elseif N < 8192
    pad = 8192-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = (x-mean(x))./std(x);
x = x(:)';
out = {x,info.Label};
end

function [dlX,dlY] = processSpeechMB(Xcell,Ycell)
% This function is only for use in the "Learn Pre-Emphasis Filter using
% Deep Learning" example. It may change or be removed in a
% future release.

Xcell = cellfun(@(x)reshape(x,1,1,[]),Xcell,'uni',false);
dlX = cat(2,Xcell{:});
dlY = cat(2,Ycell{:});
dlY = onehotencode(dlY,1);
end

function [grads,loss,state] = modelGradSTFT(net,X,T)
% This function is only for use in the "Learn Pre-Emphasis Filter using
% Deep Learning" example. It may change or be removed in a
% future release.

[y,state] = net.forward(X);
loss = crossentropy(y,T);
grads = dlgradient(loss,net.Learnables);
loss = double(gather(extractdata(loss)));
end

参考

アプリ

オブジェクト

関数

  • (Signal Processing Toolbox) | (Signal Processing Toolbox) | (Signal Processing Toolbox) | (Signal Processing Toolbox) | (Signal Processing Toolbox)

関連するトピック