深層学習を使用した音声コマンド認識
この例では、オーディオに存在する音声コマンドを検出する深層学習モデルに学習させる方法を説明します。この例では、Speech Commands Dataset [1] を使用して、与えられた一連のコマンドを認識する畳み込みニューラル ネットワークに学習させます。
ネットワークにゼロから学習させるには、最初にデータセットをダウンロードしなければなりません。データセットのダウンロードやネットワークの学習を行わない場合は、この例にある学習済みのネットワークを読み込んで、以下の 2 節 事前学習済みのネットワークを使用したコマンド認識とマイクからのストリーミング オーディオを使用したコマンド検出を実行することができます。
事前学習済みのネットワークを使用したコマンド認識
学習プロセスの詳細について説明する前に、事前学習済みの音声認識ネットワークを使用して音声コマンドを識別します。
事前学習済みのネットワークを読み込みます。
load("commandNet.mat")
このネットワークは、"yes"、"no"、"up"、"down"、"left"、"right"、"on"、"off"、"stop"、および "go" の音声コマンドを認識するように学習されています。
人が "stop" と発声している短い音声信号を読み込みます。
[x,fs] = audioread("stop_command.flac");
コマンドを聞きます。
sound(x,fs)
事前学習済みのネットワークは、聴覚ベースのスペクトログラムを入力として受け取ります。最初に音声の波形を聴覚ベースのスペクトログラムに変換します。
関数 helperExtractAuditoryFeature
を使用して聴覚スペクトログラムを計算します。特徴抽出の詳細については、この例で後ほど説明します。
auditorySpect = helperExtractAuditoryFeatures(x,fs);
聴覚スペクトログラムに基づいてコマンドを分類します。
command = classify(trainedNet,auditorySpect)
command = categorical
stop
このネットワークは、このセットに属さない単語を "unknown" と分類するように学習されています。
次に、識別するコマンドの一覧に含まれていない単語 ("play") を分類します。
まず、音声信号を読み込んで聞きます。
x = audioread("play_command.flac");
sound(x,fs)
聴覚スペクトログラムを計算します。
auditorySpect = helperExtractAuditoryFeatures(x,fs);
信号を分類します。
command = classify(trainedNet,auditorySpect)
command = categorical
unknown
このネットワークは、バックグラウンド ノイズを "background" として分類するように学習されています。
ランダムなノイズで構成される 1 秒間の信号を作成します。
x = pinknoise(16e3);
聴覚スペクトログラムを計算します。
auditorySpect = helperExtractAuditoryFeatures(x,fs);
バックグラウンド ノイズを分類します。
command = classify(trainedNet,auditorySpect)
command = categorical
background
マイクからのストリーミング オーディオを使用したコマンドの検出
事前学習済みのコマンド検出ネットワークをマイクからのストリーミング オーディオでテストします。yes、no、stop など、いずれかのコマンドを発声してみてください。さらに、Marvin、Sheila、bed、house、cat、bird などの未知の単語や、0 から 9 までの数のいずれかを発声してみてください。
分類レートを Hz 単位で指定し、マイクからオーディオを読み取ることができるオーディオ デバイス リーダーを作成します。
classificationRate = 20; adr = audioDeviceReader(SampleRate=fs,SamplesPerFrame=floor(fs/classificationRate));
オーディオのバッファーを初期化します。ネットワークの分類ラベルを抽出します。ストリーミング オーディオのラベルと分類確率用の 0.5 秒のバッファーを初期化します。これらのバッファーを使用して長い時間にわたって分類結果を比較し、それによって、コマンドが検出されたタイミングとの '一致' を構築します。判定ロジックのしきい値を指定します。
audioBuffer = dsp.AsyncBuffer(fs);
labels = trainedNet.Layers(end).Classes;
YBuffer(1:classificationRate/2) = categorical("background");
probBuffer = zeros([numel(labels),classificationRate/2]);
countThreshold = ceil(classificationRate*0.2);
probThreshold = 0.7;
マイクからのオーディオ入力を可視化する timescope
オブジェクトを作成します。予測に使用した聴覚スペクトログラムを可視化する dsp.MatrixViewer
オブジェクトを作成します。
wavePlotter = timescope( ... SampleRate=fs, ... Title="...", ... TimeSpanSource="property", ... TimeSpan=1, ... YLimits=[-1,1], ... Position=[600,640,800,340], ... TimeAxisLabels="none", ... AxesScaling="manual"); show(wavePlotter) specPlotter = dsp.MatrixViewer( ... XDataMode="Custom", ... AxisOrigin="Lower left corner", ... Position=[600,220,800,380], ... ShowGrid=false, ... Title="...", ... XLabel="Time (s)", ... YLabel="Bark (bin)"); show(specPlotter)
マイクからのオーディオ入力を使用して、ライブ音声コマンド認識を実行します。ループ処理を無限に実行するには、timeLimit
を Inf
に設定します。ライブ検出を停止するには、timescope と dsp.MatrixViewer
の Figure を閉じます。
% Initialize variables for plotting currentTime = 0; colorLimits = [-1,1]; timeLimit = 10; tic while toc<timeLimit && isVisible(wavePlotter) && isVisible(specPlotter) % Extract audio samples from the audio device and add the samples to % the buffer. x = adr(); write(audioBuffer,x); y = read(audioBuffer,fs,fs-adr.SamplesPerFrame); spec = helperExtractAuditoryFeatures(y,fs); % Classify the current spectrogram, save the label to the label buffer, % and save the predicted probabilities to the probability buffer. [YPredicted,probs] = classify(trainedNet,spec,ExecutionEnvironment="cpu"); YBuffer = [YBuffer(2:end),YPredicted]; probBuffer = [probBuffer(:,2:end),probs(:)]; % Plot the current waveform and spectrogram. wavePlotter(y(end-adr.SamplesPerFrame+1:end)) specPlotter(spec') % Now do the actual command detection by performing a thresholding operation. % Declare a detection and display it in the figure if the following hold: % 1) The most common label is not background. % 2) At last countThreshold of the latest frame labels agree. % 3) The maximum probability of the predicted label is at least probThreshold. % Otherwise, do not declare a detection. [YMode,count] = mode(YBuffer); maxProb = max(probBuffer(labels == YMode,:)); if YMode == "background" || count < countThreshold || maxProb < probThreshold wavePlotter.Title = "..."; specPlotter.Title = "..."; else wavePlotter.Title = string(YMode); specPlotter.Title = string(YMode); end % Update variables for plotting currentTime = currentTime + adr.SamplesPerFrame/fs; colorLimits = [min([colorLimits(1),min(spec,[],"all")]),max([colorLimits(2),max(spec,[],"all")])]; specPlotter.CustomXData = [currentTime-1,currentTime]; specPlotter.ColorLimits = colorLimits; end release(wavePlotter)
release(specPlotter)
音声コマンド データセットの読み込み
この例では、Google Speech Commands Datase [1] を使用します。データセットをダウンロードして、ダウンロードしたファイルを解凍します。
downloadFolder = matlab.internal.examples.downloadSupportFile("audio","google_speech.zip"); dataFolder = tempdir; unzip(downloadFolder,dataFolder) dataset = fullfile(dataFolder,"google_speech");
学習データストアの作成
学習データ セットを指すaudioDatastore
(Audio Toolbox)を作成します。
ads = audioDatastore(fullfile(dataset,"train"), ... IncludeSubfolders=true, ... FileExtensions=".wav", ... LabelSource="foldernames")
ads = audioDatastore with properties: Files: { ' ...\AppData\Local\Temp\google_speech\train\bed\00176480_nohash_0.wav'; ' ...\AppData\Local\Temp\google_speech\train\bed\004ae714_nohash_0.wav'; ' ...\AppData\Local\Temp\google_speech\train\bed\004ae714_nohash_1.wav' ... and 51085 more } Folders: { 'C:\Users\bhemmat\AppData\Local\Temp\google_speech\train' } Labels: [bed; bed; bed ... and 51085 more categorical] AlternateFileSystemRoots: {} OutputDataType: 'double' SupportedOutputFormats: ["wav" "flac" "ogg" "mp4" "m4a"] DefaultOutputFormat: "wav"
認識する単語の選択
モデルにコマンドとして認識させる単語を指定します。コマンドではないすべての単語に unknown
とラベル付けします。コマンドではない単語に unknown
とラベル付けすることで、コマンド以外のすべての単語の分布に近い単語のグループが作成されます。ネットワークは、このグループを使用して、コマンドと他のすべての単語の違いを学習します。
既知の単語と未知の単語の間でクラスの不均衡を減らし、処理を高速化するために、未知の単語の一部のみを学習セットに含めます。
subset
(Audio Toolbox)を使用して、不明な単語のサブセットとコマンドのみが含まれるデータストアを作成します。各カテゴリに属している例の数をカウントします。
commands = categorical(["yes","no","up","down","left","right","on","off","stop","go"]); isCommand = ismember(ads.Labels,commands); isUnknown = ~isCommand; includeFraction = 0.2; mask = rand(numel(ads.Labels),1) < includeFraction; isUnknown = isUnknown & mask; ads.Labels(isUnknown) = categorical("unknown"); adsTrain = subset(ads,isCommand|isUnknown); countEachLabel(adsTrain)
ans=11×2 table
Label Count
_______ _____
down 1842
go 1861
left 1839
no 1853
off 1839
on 1864
right 1852
stop 1885
unknown 6490
up 1843
yes 1860
検証データストアの作成
検証データ セットを指すaudioDatastore
(Audio Toolbox)を作成します。学習データストアの作成に用いたのと同じ手順に従います。
ads = audioDatastore(fullfile(dataset,"validation"), ... IncludeSubfolders=true, ... FileExtensions=".wav", ... LabelSource="foldernames")
ads = audioDatastore with properties: Files: { ' ...\AppData\Local\Temp\google_speech\validation\bed\026290a7_nohash_0.wav'; ' ...\AppData\Local\Temp\google_speech\validation\bed\060cd039_nohash_0.wav'; ' ...\AppData\Local\Temp\google_speech\validation\bed\060cd039_nohash_1.wav' ... and 6795 more } Folders: { 'C:\Users\bhemmat\AppData\Local\Temp\google_speech\validation' } Labels: [bed; bed; bed ... and 6795 more categorical] AlternateFileSystemRoots: {} OutputDataType: 'double' SupportedOutputFormats: ["wav" "flac" "ogg" "mp4" "m4a"] DefaultOutputFormat: "wav"
isCommand = ismember(ads.Labels,commands);
isUnknown = ~isCommand;
includeFraction = 0.2;
mask = rand(numel(ads.Labels),1) < includeFraction;
isUnknown = isUnknown & mask;
ads.Labels(isUnknown) = categorical("unknown");
adsValidation = subset(ads,isCommand|isUnknown);
countEachLabel(adsValidation)
ans=11×2 table
Label Count
_______ _____
down 264
go 260
left 247
no 270
off 256
on 257
right 256
stop 246
unknown 828
up 260
yes 261
データセット全体を使ってネットワークに学習させ、できる限り精度を高くするには、speedupExample
を false
に設定します。この例を短時間で実行するには、speedupExample
を true
に設定します。
speedupExample =false; if speedupExample numUniqueLabels = numel(unique(adsTrain.Labels)); % Reduce the dataset by a factor of 20 adsTrain = splitEachLabel(adsTrain,round(numel(adsTrain.Files)/numUniqueLabels/20)); adsValidation = splitEachLabel(adsValidation,round(numel(adsValidation.Files)/numUniqueLabels/20)); end
聴覚スペクトログラムの計算
畳み込みニューラル ネットワークの学習を効果的に行うためにデータを準備するには、音声波形を聴覚ベースのスペクトログラムに変換します。
特徴抽出のパラメーターを定義します。segmentDuration
は各音声クリップの長さ (秒) です。frameDuration
はスペクトル計算の各フレームの長さです。hopDuration
は各スペクトル間のタイム ステップです。numBands
は聴覚スペクトログラムのフィルター数です。
audioFeatureExtractor
(Audio Toolbox)オブジェクトを作成して特徴抽出を実行します。
fs = 16e3; % Known sample rate of the data set. segmentDuration = 1; frameDuration = 0.025; hopDuration = 0.010; segmentSamples = round(segmentDuration*fs); frameSamples = round(frameDuration*fs); hopSamples = round(hopDuration*fs); overlapSamples = frameSamples - hopSamples; FFTLength = 512; numBands = 50; afe = audioFeatureExtractor( ... SampleRate=fs, ... FFTLength=FFTLength, ... Window=hann(frameSamples,"periodic"), ... OverlapLength=overlapSamples, ... barkSpectrum=true); setExtractorParameters(afe,"barkSpectrum",NumBands=numBands,WindowNormalization=false);
データセットからファイルを読み取ります。畳み込みニューラル ネットワークに学習させるには、入力が一定サイズでなければなりません。データセットの一部のファイルは長さが 1 秒未満です。ゼロ パディングをオーディオ信号の前後に適用して長さを segmentSamples
にします。
x = read(adsTrain); numSamples = size(x,1); numToPadFront = floor((segmentSamples - numSamples)/2); numToPadBack = ceil((segmentSamples - numSamples)/2); xPadded = [zeros(numToPadFront,1,"like",x);x;zeros(numToPadBack,1,"like",x)];
オーディオの特徴を抽出するには、extract
を呼び出します。出力は行に沿った時間をもつバーク スペクトルです。
features = extract(afe,xPadded); [numHops,numFeatures] = size(features)
numHops = 98
numFeatures = 50
この例では、聴覚スペクトログラムに対数を適用して後処理します。小さい数字の対数を取ると、丸め誤差の原因になります。
処理を高速化するために、parfor
を使用して複数のワーカーに特徴抽出を分散できます。
最初に、データセットの区画数を決定します。Parallel Computing Toolbox™ がない場合は、単一の区画を使用します。
if ~isempty(ver("parallel")) && ~speedupExample pool = gcp; numPar = numpartitions(adsTrain,pool); else numPar = 1; end
各区画について、データストアから読み取り、信号をゼロ パディングしてから、特徴を抽出します。
parfor ii = 1:numPar subds = partition(adsTrain,numPar,ii); XTrain = zeros(numHops,numBands,1,numel(subds.Files)); for idx = 1:numel(subds.Files) x = read(subds); xPadded = [zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)]; XTrain(:,:,:,idx) = extract(afe,xPadded); end XTrainC{ii} = XTrain; end
出力を変換し、4 番目の次元に聴覚スペクトログラムをもつ 4 次元配列にします。
XTrain = cat(4,XTrainC{:}); [numHops,numBands,numChannels,numSpec] = size(XTrain)
numHops = 98
numBands = 50
numChannels = 1
numSpec = 25028
ウィンドウのべき乗で特徴をスケーリングしてから対数を取ります。滑らかな分布のデータを得るために、小さいオフセットを使用してスペクトログラムの対数を取ります。
epsil = 1e-6; XTrain = log10(XTrain + epsil);
検証セットに対して、上記で説明した特徴抽出の手順を実行します。
if ~isempty(ver("parallel")) pool = gcp; numPar = numpartitions(adsValidation,pool); else numPar = 1; end parfor ii = 1:numPar subds = partition(adsValidation,numPar,ii); XValidation = zeros(numHops,numBands,1,numel(subds.Files)); for idx = 1:numel(subds.Files) x = read(subds); xPadded = [zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)]; XValidation(:,:,:,idx) = extract(afe,xPadded); end XValidationC{ii} = XValidation; end XValidation = cat(4,XValidationC{:}); XValidation = log10(XValidation + epsil);
学習ラベルと検証ラベルを分離します。空のカテゴリを削除します。
TTrain = removecats(adsTrain.Labels); TValidation = removecats(adsValidation.Labels);
データの可視化
いくつかの学習サンプルについて波形と聴覚スペクトログラムをプロットします。対応するオーディオ クリップを再生します。
specMin = min(XTrain,[],"all"); specMax = max(XTrain,[],"all"); idx = randperm(numel(adsTrain.Files),3); figure(Units="normalized",Position=[0.2 0.2 0.6 0.6]); for ii = 1:3 [x,fs] = audioread(adsTrain.Files{idx(ii)}); subplot(2,3,ii) plot(x) axis tight title(string(adsTrain.Labels(idx(ii)))) subplot(2,3,ii+3) spect = XTrain(:,:,1,idx(ii))'; pcolor(spect) caxis([specMin specMax]) shading flat sound(x,fs) pause(2) end
バックグラウンド ノイズ データの追加
このネットワークは、発声されたさまざまな単語を認識できるだけでなく、入力に無音部分またはバックグラウンド ノイズが含まれているかどうかを検出できなければなりません。
background フォルダーのオーディオ ファイルを使用して、バックグラウンド ノイズの 1 秒間のクリップのサンプルを作成します。各バックグラウンド ノイズ ファイルから同じ数のバックグラウンド クリップを作成します。また、バックグラウンド ノイズの録音を独自に作成して、background フォルダーに追加することもできます。スペクトログラムを計算する前に、この関数は、対数一様分布からサンプリングされた係数を使用して、volumeRange
で与えられた範囲に各オーディオ クリップを再スケーリングします。
adsBkg = audioDatastore(fullfile(dataset,"background"))
adsBkg = audioDatastore with properties: Files: { ' ...\AppData\Local\Temp\google_speech\background\doing_the_dishes.wav'; ' ...\bhemmat\AppData\Local\Temp\google_speech\background\dude_miaowing.wav'; ' ...\bhemmat\AppData\Local\Temp\google_speech\background\exercise_bike.wav' ... and 3 more } Folders: { 'C:\Users\bhemmat\AppData\Local\Temp\google_speech\background' } AlternateFileSystemRoots: {} OutputDataType: 'double' Labels: {} SupportedOutputFormats: ["wav" "flac" "ogg" "mp4" "m4a"] DefaultOutputFormat: "wav"
numBkgClips = 4000; if speedupExample numBkgClips = numBkgClips/20; end volumeRange = log10([1e-4,1]); numBkgFiles = numel(adsBkg.Files); numClipsPerFile = histcounts(1:numBkgClips,linspace(1,numBkgClips,numBkgFiles+1)); Xbkg = zeros(size(XTrain,1),size(XTrain,2),1,numBkgClips,"single"); bkgAll = readall(adsBkg); ind = 1; for count = 1:numBkgFiles bkg = bkgAll{count}; idxStart = randi(numel(bkg)-fs,numClipsPerFile(count),1); idxEnd = idxStart+fs-1; gain = 10.^((volumeRange(2)-volumeRange(1))*rand(numClipsPerFile(count),1) + volumeRange(1)); for j = 1:numClipsPerFile(count) x = bkg(idxStart(j):idxEnd(j))*gain(j); x = max(min(x,1),-1); Xbkg(:,:,:,ind) = extract(afe,x); if mod(ind,1000)==0 progress = "Processed " + string(ind) + " background clips out of " + string(numBkgClips) end ind = ind + 1; end end
progress = "Processed 1000 background clips out of 4000"
progress = "Processed 2000 background clips out of 4000"
progress = "Processed 3000 background clips out of 4000"
progress = "Processed 4000 background clips out of 4000"
Xbkg = log10(Xbkg + epsil);
バックグラウンド ノイズのスペクトログラムを学習セット、検証セット、およびテスト セットに分割します。background noise フォルダーには約 5 分半のバックグラウンド ノイズのみが含まれているため、異なるデータ セットのバックグラウンド サンプルには高い相関があります。バックグラウンド ノイズのバリエーションを増やすために、独自のバックグラウンド ファイルを作成して、このフォルダーに追加できます。ノイズに対するネットワークのロバスト性を向上させるために、バックグラウンド ノイズを音声ファイルにミキシングしてみることもできます。
numTrainBkg = floor(0.85*numBkgClips); numValidationBkg = floor(0.15*numBkgClips); XTrain(:,:,:,end+1:end+numTrainBkg) = Xbkg(:,:,:,1:numTrainBkg); TTrain(end+1:end+numTrainBkg) = "background"; XValidation(:,:,:,end+1:end+numValidationBkg) = Xbkg(:,:,:,numTrainBkg+1:end); TValidation(end+1:end+numValidationBkg) = "background";
学習セットと検証セット内のさまざまなクラス ラベルの分布をプロットします。
figure(Units="normalized",Position=[0.2 0.2 0.5 0.5]) tiledlayout(2,1) nexttile histogram(TTrain) title("Training Label Distribution") nexttile histogram(TValidation) title("Validation Label Distribution")
ニューラル ネットワーク アーキテクチャの定義
シンプルなネットワーク アーキテクチャを層の配列として作成します。畳み込み層とバッチ正規化層を使用します。最大プーリング層を使って特徴マップを "空間的に" (つまり、時間と周波数に関して) ダウンサンプリングします。入力の特徴マップを時間の経過と共にグローバルにプーリングする最後の最大プーリング層を追加します。これによって (近似的な) 時間並進不変性が入力スペクトログラムに課されるため、ネットワークは、音声の正確な時間的位置とは無関係に同じ分類を実行できます。また、グローバル プーリングによって、最後の全結合層のパラメーター数が大幅に減少します。ネットワークが学習データの特定の特徴を記憶する可能性を減らすために、最後の全結合層への入力に少量のドロップアウトを追加します。
このネットワークは、フィルターがほとんどない 5 つの畳み込み層しかないため小規模です。numF
によって畳み込み層のフィルターの数を制御します。ネットワークの精度を高めるために、畳み込み層、バッチ正規化層、および ReLU 層の同等なブロックを追加して、ネットワーク深さを大きくすることを試してください。numF
を増やして、畳み込みフィルターの数を増やしてみることもできます。
各クラスの損失の合計重みを等しくするために、各クラスの学習例の数に反比例するクラスの重みを使用します。ネットワークの学習に Adam オプティマイザーを使用する場合、学習アルゴリズムはクラスの重み全体の正規化に依存しません。
classes = categories(TTrain); classWeights = 1./countcats(TTrain); classWeights = classWeights'/mean(classWeights); numClasses = numel(categories(TTrain)); timePoolSize = ceil(numHops/8); dropoutProb = 0.2; numF = 12; layers = [ imageInputLayer([numHops numBands]) convolution2dLayer(3,numF,Padding="same") batchNormalizationLayer reluLayer maxPooling2dLayer(3,Stride=2,Padding="same") convolution2dLayer(3,2*numF,Padding="same") batchNormalizationLayer reluLayer maxPooling2dLayer(3,Stride=2,Padding="same") convolution2dLayer(3,4*numF,Padding="same") batchNormalizationLayer reluLayer maxPooling2dLayer(3,Stride=2,Padding="same") convolution2dLayer(3,4*numF,Padding="same") batchNormalizationLayer reluLayer convolution2dLayer(3,4*numF,Padding="same") batchNormalizationLayer reluLayer maxPooling2dLayer([timePoolSize,1]) dropoutLayer(dropoutProb) fullyConnectedLayer(numClasses) softmaxLayer classificationLayer(Classes=classes,ClassWeights=classWeights)];
ネットワークの学習
学習オプションを指定します。ミニバッチ サイズを 128 として Adam オプティマイザーを使用します。学習は 25 エポック行い、20 エポック後に学習率を 10 分の 1 に下げます。
miniBatchSize = 128; validationFrequency = floor(numel(TTrain)/miniBatchSize); options = trainingOptions("adam", ... InitialLearnRate=3e-4, ... MaxEpochs=25, ... MiniBatchSize=miniBatchSize, ... Shuffle="every-epoch", ... Plots="training-progress", ... Verbose=false, ... ValidationData={XValidation,TValidation}, ... ValidationFrequency=validationFrequency, ... LearnRateSchedule="piecewise", ... LearnRateDropFactor=0.1, ... LearnRateDropPeriod=20);
ネットワークに学習をさせます。GPU がない場合、ネットワークの学習に時間がかかる場合があります。
trainedNet = trainNetwork(XTrain,TTrain,layers,options);
学習済みネットワークの評価
学習セット (データ拡張なし) と検証セットに対するネットワークの最終精度を計算します。このデータセットではネットワークは非常に正確になります。ただし、学習データ、検証データ、およびテスト データの分布はどれも似ていて、必ずしも実際の環境を反映していません。この制限は特に unknown
カテゴリに当てはまります。このカテゴリには、少数の単語の発話しか含まれていません。
if speedupExample load("commandNet.mat","trainedNet"); end YValPred = classify(trainedNet,XValidation); validationError = mean(YValPred ~= TValidation); YTrainPred = classify(trainedNet,XTrain); trainError = mean(YTrainPred ~= TTrain); disp("Training error: " + trainError*100 + "%")
Training error: 1.5794%
disp("Validation error: " + validationError*100 + "%")
Validation error: 4.6692%
混同行列をプロットします。列と行の要約を使用して、各クラスの適合率と再現率を表示します。混同行列のクラスを並べ替えます。大きな認識の違いが現れるのは、未知の単語間、コマンド up と off、down と no、および go と no の間にあります。
figure(Units="normalized",Position=[0.2 0.2 0.5 0.5]); cm = confusionchart(TValidation,YValPred, ... Title="Confusion Matrix for Validation Data", ... ColumnSummary="column-normalized",RowSummary="row-normalized"); sortClasses(cm,[commands,"unknown","background"])
モバイル用途など、ハードウェア リソースに制約がある用途で使用する場合は、利用可能なメモリおよび計算リソースの制限を考慮します。CPU を使用する場合は、ネットワークの合計サイズを KB 単位で計算し、その予測速度をテストします。予測時間は 1 つの入力イメージの分類にかかる時間です。複数のイメージをネットワークに入力する場合、これらを同時に分類して、イメージあたりの予測時間を短くすることができます。ストリーミング オーディオを分類する場合、1 つのイメージの予測時間が最も重要です。
info = whos("trainedNet"); disp("Network size: " + info.bytes/1024 + " kB")
Network size: 292.2139 kB
time = zeros(100,1); for ii = 1:100 x = randn([numHops,numBands]); tic [YPredicted,probs] = classify(trainedNet,x,ExecutionEnvironment="cpu"); time(ii) = toc; end disp("Single-image prediction time on CPU: " + mean(time(11:end))*1000 + " ms")
Single-image prediction time on CPU: 2.4838 ms
参考文献
[1] Warden P. "Speech Commands: A public dataset for single-word speech recognition", 2017. Available from https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Copyright Google 2017. Speech Commands Dataset は、次で公開されている Creative Commons Attribution 4.0 license に従ってライセンスされています。https://creativecommons.org/licenses/by/4.0/legalcode。
参照
[1] Warden P. "Speech Commands: A public dataset for single-word speech recognition", 2017. Available from http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Copyright Google 2017. The Speech Commands Dataset is licensed under the Creative Commons Attribution 4.0 license, available here: https://creativecommons.org/licenses/by/4.0/legalcode.
参考
trainNetwork
| classify
| analyzeNetwork