Main Content

深層学習を使用した音声コマンド認識

この例では、オーディオに存在する音声コマンドを検出する深層学習モデルに学習させる方法を説明します。この例では、Speech Commands Dataset [1] を使用して、与えられた一連のコマンドを認識する畳み込みニューラル ネットワークに学習させます。

ネットワークにゼロから学習させるには、最初にデータセットをダウンロードしなければなりません。データセットのダウンロードやネットワークの学習を行わない場合は、この例にある学習済みのネットワークを読み込んで、以下の 2 節 事前学習済みのネットワークを使用したコマンド認識マイクからのストリーミング オーディオを使用したコマンド検出を実行することができます。

事前学習済みのネットワークを使用したコマンド認識

学習プロセスの詳細について説明する前に、事前学習済みの音声認識ネットワークを使用して音声コマンドを識別します。

事前学習済みのネットワークを読み込みます。

load("commandNet.mat")

このネットワークは、"yes""no""up""down""left""right""on""off""stop"、および "go" の音声コマンドを認識するように学習されています。

人が "stop" と発声している短い音声信号を読み込みます。

[x,fs] = audioread("stop_command.flac");

コマンドを聞きます。

sound(x,fs)

事前学習済みのネットワークは、聴覚ベースのスペクトログラムを入力として受け取ります。最初に音声の波形を聴覚ベースのスペクトログラムに変換します。

関数 helperExtractAuditoryFeature を使用して聴覚スペクトログラムを計算します。特徴抽出の詳細については、この例で後ほど説明します。

auditorySpect = helperExtractAuditoryFeatures(x,fs);

聴覚スペクトログラムに基づいてコマンドを分類します。

command = classify(trainedNet,auditorySpect)
command = categorical
     stop 

このネットワークは、このセットに属さない単語を "unknown" と分類するように学習されています。

次に、識別するコマンドの一覧に含まれていない単語 ("play") を分類します。

まず、音声信号を読み込んで聞きます。

x = audioread("play_command.flac");
sound(x,fs)

聴覚スペクトログラムを計算します。

auditorySpect = helperExtractAuditoryFeatures(x,fs);

信号を分類します。

command = classify(trainedNet,auditorySpect)
command = categorical
     unknown 

このネットワークは、バックグラウンド ノイズを "background" として分類するように学習されています。

ランダムなノイズで構成される 1 秒間の信号を作成します。

x = pinknoise(16e3);

聴覚スペクトログラムを計算します。

auditorySpect = helperExtractAuditoryFeatures(x,fs);

バックグラウンド ノイズを分類します。

command = classify(trainedNet,auditorySpect)
command = categorical
     background 

マイクからのストリーミング オーディオを使用したコマンドの検出

事前学習済みのコマンド検出ネットワークをマイクからのストリーミング オーディオでテストします。yesnostop など、いずれかのコマンドを発声してみてください。さらに、MarvinSheilabedhousecatbird などの未知の単語や、0 から 9 までの数のいずれかを発声してみてください。

分類レートを Hz 単位で指定し、マイクからオーディオを読み取ることができるオーディオ デバイス リーダーを作成します。

classificationRate = 20;
adr = audioDeviceReader(SampleRate=fs,SamplesPerFrame=floor(fs/classificationRate));

オーディオのバッファーを初期化します。ネットワークの分類ラベルを抽出します。ストリーミング オーディオのラベルと分類確率用の 0.5 秒のバッファーを初期化します。これらのバッファーを使用して長い時間にわたって分類結果を比較し、それによって、コマンドが検出されたタイミングとの '一致' を構築します。判定ロジックのしきい値を指定します。

audioBuffer = dsp.AsyncBuffer(fs);

labels = trainedNet.Layers(end).Classes;
YBuffer(1:classificationRate/2) = categorical("background");

probBuffer = zeros([numel(labels),classificationRate/2]);

countThreshold = ceil(classificationRate*0.2);
probThreshold = 0.7;

マイクからのオーディオ入力を可視化する timescope オブジェクトを作成します。予測に使用した聴覚スペクトログラムを可視化する dsp.MatrixViewer オブジェクトを作成します。

wavePlotter = timescope( ...
    SampleRate=fs, ...
    Title="...", ...
    TimeSpanSource="property", ...
    TimeSpan=1, ...
    YLimits=[-1,1], ...
    Position=[600,640,800,340], ...
    TimeAxisLabels="none", ...
    AxesScaling="manual");
show(wavePlotter)

specPlotter = dsp.MatrixViewer( ...
    XDataMode="Custom", ...
    AxisOrigin="Lower left corner", ...
    Position=[600,220,800,380], ...
    ShowGrid=false, ...
    Title="...", ...
    XLabel="Time (s)", ...
    YLabel="Bark (bin)");
show(specPlotter)

マイクからのオーディオ入力を使用して、ライブ音声コマンド認識を実行します。ループ処理を無限に実行するには、timeLimitInf に設定します。ライブ検出を停止するには、timescope と dsp.MatrixViewer の Figure を閉じます。

% Initialize variables for plotting
currentTime = 0;
colorLimits = [-1,1];

timeLimit = 10;

tic
while toc<timeLimit && isVisible(wavePlotter) && isVisible(specPlotter)
    
    % Extract audio samples from the audio device and add the samples to
    % the buffer.
    x = adr();
    write(audioBuffer,x);
    y = read(audioBuffer,fs,fs-adr.SamplesPerFrame);
    
    spec = helperExtractAuditoryFeatures(y,fs);
    
    % Classify the current spectrogram, save the label to the label buffer,
    % and save the predicted probabilities to the probability buffer.
    [YPredicted,probs] = classify(trainedNet,spec,ExecutionEnvironment="cpu");
    YBuffer = [YBuffer(2:end),YPredicted];
    probBuffer = [probBuffer(:,2:end),probs(:)];
    
    % Plot the current waveform and spectrogram.
    wavePlotter(y(end-adr.SamplesPerFrame+1:end))
    specPlotter(spec')
    
    % Now do the actual command detection by performing a thresholding operation.
    % Declare a detection and display it in the figure if the following hold: 
    %   1) The most common label is not background. 
    %   2) At last countThreshold of the latest frame labels agree. 
    %   3) The maximum probability of the predicted label is at least probThreshold.
    % Otherwise, do not declare a detection.
    [YMode,count] = mode(YBuffer);
    maxProb = max(probBuffer(labels == YMode,:));
    if YMode == "background" || count < countThreshold || maxProb < probThreshold
        wavePlotter.Title = "...";
        specPlotter.Title = "...";
    else
        wavePlotter.Title = string(YMode);
        specPlotter.Title = string(YMode);
    end
    
    % Update variables for plotting
    currentTime = currentTime + adr.SamplesPerFrame/fs;
    colorLimits = [min([colorLimits(1),min(spec,[],"all")]),max([colorLimits(2),max(spec,[],"all")])];
    specPlotter.CustomXData = [currentTime-1,currentTime];
    specPlotter.ColorLimits = colorLimits;
end
release(wavePlotter)

release(specPlotter)

音声コマンド データセットの読み込み

この例では、Google Speech Commands Datase [1] を使用します。データセットをダウンロードして、ダウンロードしたファイルを解凍します。

downloadFolder = matlab.internal.examples.downloadSupportFile("audio","google_speech.zip");
dataFolder = tempdir;
unzip(downloadFolder,dataFolder)
dataset = fullfile(dataFolder,"google_speech");

学習データストアの作成

学習データ セットを指すaudioDatastore (Audio Toolbox)を作成します。

ads = audioDatastore(fullfile(dataset,"train"), ...
    IncludeSubfolders=true, ...
    FileExtensions=".wav", ...
    LabelSource="foldernames")
ads = 
  audioDatastore with properties:

                       Files: {
                              ' ...\AppData\Local\Temp\google_speech\train\bed\00176480_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\train\bed\004ae714_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\train\bed\004ae714_nohash_1.wav'
                               ... and 51085 more
                              }
                     Folders: {
                              'C:\Users\bhemmat\AppData\Local\Temp\google_speech\train'
                              }
                      Labels: [bed; bed; bed ... and 51085 more categorical]
    AlternateFileSystemRoots: {}
              OutputDataType: 'double'
      SupportedOutputFormats: ["wav"    "flac"    "ogg"    "mp4"    "m4a"]
         DefaultOutputFormat: "wav"

認識する単語の選択

モデルにコマンドとして認識させる単語を指定します。コマンドではないすべての単語に unknown とラベル付けします。コマンドではない単語に unknown とラベル付けすることで、コマンド以外のすべての単語の分布に近い単語のグループが作成されます。ネットワークは、このグループを使用して、コマンドと他のすべての単語の違いを学習します。

既知の単語と未知の単語の間でクラスの不均衡を減らし、処理を高速化するために、未知の単語の一部のみを学習セットに含めます。

subset (Audio Toolbox)を使用して、不明な単語のサブセットとコマンドのみが含まれるデータストアを作成します。各カテゴリに属している例の数をカウントします。

commands = categorical(["yes","no","up","down","left","right","on","off","stop","go"]);

isCommand = ismember(ads.Labels,commands);
isUnknown = ~isCommand;

includeFraction = 0.2;
mask = rand(numel(ads.Labels),1) < includeFraction;
isUnknown = isUnknown & mask;
ads.Labels(isUnknown) = categorical("unknown");

adsTrain = subset(ads,isCommand|isUnknown);
countEachLabel(adsTrain)
ans=11×2 table
     Label     Count
    _______    _____

    down       1842 
    go         1861 
    left       1839 
    no         1853 
    off        1839 
    on         1864 
    right      1852 
    stop       1885 
    unknown    6490 
    up         1843 
    yes        1860 

検証データストアの作成

検証データ セットを指すaudioDatastore (Audio Toolbox)を作成します。学習データストアの作成に用いたのと同じ手順に従います。

ads = audioDatastore(fullfile(dataset,"validation"), ...
    IncludeSubfolders=true, ...
    FileExtensions=".wav", ...
    LabelSource="foldernames")
ads = 
  audioDatastore with properties:

                       Files: {
                              ' ...\AppData\Local\Temp\google_speech\validation\bed\026290a7_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\validation\bed\060cd039_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\validation\bed\060cd039_nohash_1.wav'
                               ... and 6795 more
                              }
                     Folders: {
                              'C:\Users\bhemmat\AppData\Local\Temp\google_speech\validation'
                              }
                      Labels: [bed; bed; bed ... and 6795 more categorical]
    AlternateFileSystemRoots: {}
              OutputDataType: 'double'
      SupportedOutputFormats: ["wav"    "flac"    "ogg"    "mp4"    "m4a"]
         DefaultOutputFormat: "wav"

isCommand = ismember(ads.Labels,commands);
isUnknown = ~isCommand;

includeFraction = 0.2;
mask = rand(numel(ads.Labels),1) < includeFraction;
isUnknown = isUnknown & mask;
ads.Labels(isUnknown) = categorical("unknown");

adsValidation = subset(ads,isCommand|isUnknown);
countEachLabel(adsValidation)
ans=11×2 table
     Label     Count
    _______    _____

    down        264 
    go          260 
    left        247 
    no          270 
    off         256 
    on          257 
    right       256 
    stop        246 
    unknown     828 
    up          260 
    yes         261 

データセット全体を使ってネットワークに学習させ、できる限り精度を高くするには、speedupExamplefalse に設定します。この例を短時間で実行するには、speedupExampletrue に設定します。

speedupExample = false;
if speedupExample
    numUniqueLabels = numel(unique(adsTrain.Labels));
    % Reduce the dataset by a factor of 20
    adsTrain = splitEachLabel(adsTrain,round(numel(adsTrain.Files)/numUniqueLabels/20));
    adsValidation = splitEachLabel(adsValidation,round(numel(adsValidation.Files)/numUniqueLabels/20));
end

聴覚スペクトログラムの計算

畳み込みニューラル ネットワークの学習を効果的に行うためにデータを準備するには、音声波形を聴覚ベースのスペクトログラムに変換します。

特徴抽出のパラメーターを定義します。segmentDuration は各音声クリップの長さ (秒) です。frameDuration はスペクトル計算の各フレームの長さです。hopDuration は各スペクトル間のタイム ステップです。numBands は聴覚スペクトログラムのフィルター数です。

audioFeatureExtractor (Audio Toolbox)オブジェクトを作成して特徴抽出を実行します。

fs = 16e3; % Known sample rate of the data set.

segmentDuration = 1;
frameDuration = 0.025;
hopDuration = 0.010;

segmentSamples = round(segmentDuration*fs);
frameSamples = round(frameDuration*fs);
hopSamples = round(hopDuration*fs);
overlapSamples = frameSamples - hopSamples;

FFTLength = 512;
numBands = 50;

afe = audioFeatureExtractor( ...
    SampleRate=fs, ...
    FFTLength=FFTLength, ...
    Window=hann(frameSamples,"periodic"), ...
    OverlapLength=overlapSamples, ...
    barkSpectrum=true);
setExtractorParameters(afe,"barkSpectrum",NumBands=numBands,WindowNormalization=false);

データセットからファイルを読み取ります。畳み込みニューラル ネットワークに学習させるには、入力が一定サイズでなければなりません。データセットの一部のファイルは長さが 1 秒未満です。ゼロ パディングをオーディオ信号の前後に適用して長さを segmentSamples にします。

x = read(adsTrain);

numSamples = size(x,1);

numToPadFront = floor((segmentSamples - numSamples)/2);
numToPadBack = ceil((segmentSamples - numSamples)/2);

xPadded = [zeros(numToPadFront,1,"like",x);x;zeros(numToPadBack,1,"like",x)];

オーディオの特徴を抽出するには、extract を呼び出します。出力は行に沿った時間をもつバーク スペクトルです。

features = extract(afe,xPadded);
[numHops,numFeatures] = size(features)
numHops = 98
numFeatures = 50

この例では、聴覚スペクトログラムに対数を適用して後処理します。小さい数字の対数を取ると、丸め誤差の原因になります。

処理を高速化するために、parfor を使用して複数のワーカーに特徴抽出を分散できます。

最初に、データセットの区画数を決定します。Parallel Computing Toolbox™ がない場合は、単一の区画を使用します。

if ~isempty(ver("parallel")) && ~speedupExample
    pool = gcp;
    numPar = numpartitions(adsTrain,pool);
else
    numPar = 1;
end

各区画について、データストアから読み取り、信号をゼロ パディングしてから、特徴を抽出します。

parfor ii = 1:numPar
    subds = partition(adsTrain,numPar,ii);
    XTrain = zeros(numHops,numBands,1,numel(subds.Files));
    for idx = 1:numel(subds.Files)
        x = read(subds);
        xPadded = [zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)];
        XTrain(:,:,:,idx) = extract(afe,xPadded);
    end
    XTrainC{ii} = XTrain;
end

出力を変換し、4 番目の次元に聴覚スペクトログラムをもつ 4 次元配列にします。

XTrain = cat(4,XTrainC{:});

[numHops,numBands,numChannels,numSpec] = size(XTrain)
numHops = 98
numBands = 50
numChannels = 1
numSpec = 25028

ウィンドウのべき乗で特徴をスケーリングしてから対数を取ります。滑らかな分布のデータを得るために、小さいオフセットを使用してスペクトログラムの対数を取ります。

epsil = 1e-6;
XTrain = log10(XTrain + epsil);

検証セットに対して、上記で説明した特徴抽出の手順を実行します。

if ~isempty(ver("parallel"))
    pool = gcp;
    numPar = numpartitions(adsValidation,pool);
else
    numPar = 1;
end

parfor ii = 1:numPar
    subds = partition(adsValidation,numPar,ii);
    XValidation = zeros(numHops,numBands,1,numel(subds.Files));
    for idx = 1:numel(subds.Files)
        x = read(subds);
        xPadded = [zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)];
        XValidation(:,:,:,idx) = extract(afe,xPadded);
    end
    XValidationC{ii} = XValidation;
end
XValidation = cat(4,XValidationC{:});
XValidation = log10(XValidation + epsil);

学習ラベルと検証ラベルを分離します。空のカテゴリを削除します。

TTrain = removecats(adsTrain.Labels);
TValidation = removecats(adsValidation.Labels);

データの可視化

いくつかの学習サンプルについて波形と聴覚スペクトログラムをプロットします。対応するオーディオ クリップを再生します。

specMin = min(XTrain,[],"all");
specMax = max(XTrain,[],"all");
idx = randperm(numel(adsTrain.Files),3);
figure(Units="normalized",Position=[0.2 0.2 0.6 0.6]);
for ii = 1:3
    [x,fs] = audioread(adsTrain.Files{idx(ii)});

    subplot(2,3,ii)
    plot(x)
    axis tight
    title(string(adsTrain.Labels(idx(ii))))
    
    subplot(2,3,ii+3)
    spect = XTrain(:,:,1,idx(ii))';
    pcolor(spect)
    caxis([specMin specMax])
    shading flat
    
    sound(x,fs)
    pause(2)
end

バックグラウンド ノイズ データの追加

このネットワークは、発声されたさまざまな単語を認識できるだけでなく、入力に無音部分またはバックグラウンド ノイズが含まれているかどうかを検出できなければなりません。

background フォルダーのオーディオ ファイルを使用して、バックグラウンド ノイズの 1 秒間のクリップのサンプルを作成します。各バックグラウンド ノイズ ファイルから同じ数のバックグラウンド クリップを作成します。また、バックグラウンド ノイズの録音を独自に作成して、background フォルダーに追加することもできます。スペクトログラムを計算する前に、この関数は、対数一様分布からサンプリングされた係数を使用して、volumeRange で与えられた範囲に各オーディオ クリップを再スケーリングします。

adsBkg = audioDatastore(fullfile(dataset,"background"))
adsBkg = 
  audioDatastore with properties:

                       Files: {
                              ' ...\AppData\Local\Temp\google_speech\background\doing_the_dishes.wav';
                              ' ...\bhemmat\AppData\Local\Temp\google_speech\background\dude_miaowing.wav';
                              ' ...\bhemmat\AppData\Local\Temp\google_speech\background\exercise_bike.wav'
                               ... and 3 more
                              }
                     Folders: {
                              'C:\Users\bhemmat\AppData\Local\Temp\google_speech\background'
                              }
    AlternateFileSystemRoots: {}
              OutputDataType: 'double'
                      Labels: {}
      SupportedOutputFormats: ["wav"    "flac"    "ogg"    "mp4"    "m4a"]
         DefaultOutputFormat: "wav"

numBkgClips = 4000;
if speedupExample
    numBkgClips = numBkgClips/20;
end
volumeRange = log10([1e-4,1]);

numBkgFiles = numel(adsBkg.Files);
numClipsPerFile = histcounts(1:numBkgClips,linspace(1,numBkgClips,numBkgFiles+1));
Xbkg = zeros(size(XTrain,1),size(XTrain,2),1,numBkgClips,"single");
bkgAll = readall(adsBkg);
ind = 1;

for count = 1:numBkgFiles
    bkg = bkgAll{count};
    idxStart = randi(numel(bkg)-fs,numClipsPerFile(count),1);
    idxEnd = idxStart+fs-1;
    gain = 10.^((volumeRange(2)-volumeRange(1))*rand(numClipsPerFile(count),1) + volumeRange(1));
    for j = 1:numClipsPerFile(count)
        
        x = bkg(idxStart(j):idxEnd(j))*gain(j);
        
        x = max(min(x,1),-1);
        
        Xbkg(:,:,:,ind) = extract(afe,x);
        
        if mod(ind,1000)==0
            progress = "Processed " + string(ind) + " background clips out of " + string(numBkgClips)
        end
        ind = ind + 1;
    end
end
progress = 
"Processed 1000 background clips out of 4000"
progress = 
"Processed 2000 background clips out of 4000"
progress = 
"Processed 3000 background clips out of 4000"
progress = 
"Processed 4000 background clips out of 4000"
Xbkg = log10(Xbkg + epsil);

バックグラウンド ノイズのスペクトログラムを学習セット、検証セット、およびテスト セットに分割します。background noise フォルダーには約 5 分半のバックグラウンド ノイズのみが含まれているため、異なるデータ セットのバックグラウンド サンプルには高い相関があります。バックグラウンド ノイズのバリエーションを増やすために、独自のバックグラウンド ファイルを作成して、このフォルダーに追加できます。ノイズに対するネットワークのロバスト性を向上させるために、バックグラウンド ノイズを音声ファイルにミキシングしてみることもできます。

numTrainBkg = floor(0.85*numBkgClips);
numValidationBkg = floor(0.15*numBkgClips);

XTrain(:,:,:,end+1:end+numTrainBkg) = Xbkg(:,:,:,1:numTrainBkg);
TTrain(end+1:end+numTrainBkg) = "background";

XValidation(:,:,:,end+1:end+numValidationBkg) = Xbkg(:,:,:,numTrainBkg+1:end);
TValidation(end+1:end+numValidationBkg) = "background";

学習セットと検証セット内のさまざまなクラス ラベルの分布をプロットします。

figure(Units="normalized",Position=[0.2 0.2 0.5 0.5])

tiledlayout(2,1)

nexttile
histogram(TTrain)
title("Training Label Distribution")

nexttile
histogram(TValidation)
title("Validation Label Distribution")

ニューラル ネットワーク アーキテクチャの定義

シンプルなネットワーク アーキテクチャを層の配列として作成します。畳み込み層とバッチ正規化層を使用します。最大プーリング層を使って特徴マップを "空間的に" (つまり、時間と周波数に関して) ダウンサンプリングします。入力の特徴マップを時間の経過と共にグローバルにプーリングする最後の最大プーリング層を追加します。これによって (近似的な) 時間並進不変性が入力スペクトログラムに課されるため、ネットワークは、音声の正確な時間的位置とは無関係に同じ分類を実行できます。また、グローバル プーリングによって、最後の全結合層のパラメーター数が大幅に減少します。ネットワークが学習データの特定の特徴を記憶する可能性を減らすために、最後の全結合層への入力に少量のドロップアウトを追加します。

このネットワークは、フィルターがほとんどない 5 つの畳み込み層しかないため小規模です。numF によって畳み込み層のフィルターの数を制御します。ネットワークの精度を高めるために、畳み込み層、バッチ正規化層、および ReLU 層の同等なブロックを追加して、ネットワーク深さを大きくすることを試してください。numF を増やして、畳み込みフィルターの数を増やしてみることもできます。

各クラスの損失の合計重みを等しくするために、各クラスの学習例の数に反比例するクラスの重みを使用します。ネットワークの学習に Adam オプティマイザーを使用する場合、学習アルゴリズムはクラスの重み全体の正規化に依存しません。

classes = categories(TTrain);
classWeights = 1./countcats(TTrain);
classWeights = classWeights'/mean(classWeights);
numClasses = numel(categories(TTrain));

timePoolSize = ceil(numHops/8);

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer([numHops numBands])
    
    convolution2dLayer(3,numF,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(3,Stride=2,Padding="same")
    
    convolution2dLayer(3,2*numF,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(3,Stride=2,Padding="same")
    
    convolution2dLayer(3,4*numF,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(3,Stride=2,Padding="same")
    
    convolution2dLayer(3,4*numF,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer([timePoolSize,1])
    
    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer(Classes=classes,ClassWeights=classWeights)];

ネットワークの学習

学習オプションを指定します。ミニバッチ サイズを 128 として Adam オプティマイザーを使用します。学習は 25 エポック行い、20 エポック後に学習率を 10 分の 1 に下げます。

miniBatchSize = 128;
validationFrequency = floor(numel(TTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    InitialLearnRate=3e-4, ...
    MaxEpochs=25, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false, ...
    ValidationData={XValidation,TValidation}, ...
    ValidationFrequency=validationFrequency, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.1, ...
    LearnRateDropPeriod=20);

ネットワークに学習をさせます。GPU がない場合、ネットワークの学習に時間がかかる場合があります。

trainedNet = trainNetwork(XTrain,TTrain,layers,options);

学習済みネットワークの評価

学習セット (データ拡張なし) と検証セットに対するネットワークの最終精度を計算します。このデータセットではネットワークは非常に正確になります。ただし、学習データ、検証データ、およびテスト データの分布はどれも似ていて、必ずしも実際の環境を反映していません。この制限は特に unknown カテゴリに当てはまります。このカテゴリには、少数の単語の発話しか含まれていません。

if speedupExample
    load("commandNet.mat","trainedNet");
end
YValPred = classify(trainedNet,XValidation);
validationError = mean(YValPred ~= TValidation);
YTrainPred = classify(trainedNet,XTrain);
trainError = mean(YTrainPred ~= TTrain);

disp("Training error: " + trainError*100 + "%")
Training error: 1.5794%
disp("Validation error: " + validationError*100 + "%")
Validation error: 4.6692%

混同行列をプロットします。列と行の要約を使用して、各クラスの適合率と再現率を表示します。混同行列のクラスを並べ替えます。大きな認識の違いが現れるのは、未知の単語間、コマンド upoffdownno、および gono の間にあります。

figure(Units="normalized",Position=[0.2 0.2 0.5 0.5]);
cm = confusionchart(TValidation,YValPred, ...
    Title="Confusion Matrix for Validation Data", ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");
sortClasses(cm,[commands,"unknown","background"])

モバイル用途など、ハードウェア リソースに制約がある用途で使用する場合は、利用可能なメモリおよび計算リソースの制限を考慮します。CPU を使用する場合は、ネットワークの合計サイズを KB 単位で計算し、その予測速度をテストします。予測時間は 1 つの入力イメージの分類にかかる時間です。複数のイメージをネットワークに入力する場合、これらを同時に分類して、イメージあたりの予測時間を短くすることができます。ストリーミング オーディオを分類する場合、1 つのイメージの予測時間が最も重要です。

info = whos("trainedNet");
disp("Network size: " + info.bytes/1024 + " kB")
Network size: 292.2139 kB
time = zeros(100,1);
for ii = 1:100
    x = randn([numHops,numBands]);
    tic
    [YPredicted,probs] = classify(trainedNet,x,ExecutionEnvironment="cpu");
    time(ii) = toc;
end
disp("Single-image prediction time on CPU: " + mean(time(11:end))*1000 + " ms")
Single-image prediction time on CPU: 2.4838 ms

参考文献

[1] Warden P. "Speech Commands: A public dataset for single-word speech recognition", 2017. Available from https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Copyright Google 2017. Speech Commands Dataset は、次で公開されている Creative Commons Attribution 4.0 license に従ってライセンスされています。https://creativecommons.org/licenses/by/4.0/legalcode

参照

[1] Warden P. "Speech Commands: A public dataset for single-word speech recognition", 2017. Available from http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Copyright Google 2017. The Speech Commands Dataset is licensed under the Creative Commons Attribution 4.0 license, available here: https://creativecommons.org/licenses/by/4.0/legalcode.

参考

| |

関連するトピック