ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

ウェーブレット時間散乱を使用した音楽ジャンルの分類

この例では、ウェーブレット時間散乱とオーディオ データストアを使用して、音楽の抜粋のジャンルを分類する方法を示します。ウェーブレット散乱では、データは、データの低分散の表現を作成するために一連のウェーブレット変換、非線形性、および平均化全体に伝播されます。これらの低分散の表現は、その後分類器で入力として使用されます。

GTZAN データセット

この例で使用されたデータセットは GTZAN ジャンル コレクションです [7][8]。データは約 1.2 GB の圧縮された tar アーカイブとして提供されます。非圧縮データセットには約 3 GB のディスク領域が必要です。上記に示したリンクから圧縮された tar ファイルを抽出して、10 個のサブフォルダーをもつフォルダーを作成します。各サブフォルダーは含まれる音楽サンプルのジャンルの名前が付けられます。ジャンルは、ブルース、クラシック、カントリー、ディスコ、ヒップホップ、ジャズ、メタル、ポップ、レゲエ、およびロックです。各ジャンルに 100 例あり、各オーディオ ファイルは 22050 Hz でサンプリングされた約 30 秒のデータで構成されます。最初の論文では、著者は、61% の精度を達成するために各音楽の例から抽出したメル周波数ケプストラム (MFC) 係数と混合ガウス モデル (GMM) 分類を含む時間領域と周波数領域の特徴の数値を使用しました [7]。その後、深層学習ネットワークがこのデータに適用されました。ほとんどの場合、これらの深層学習方法は、深層 CNN への入力として MFC 係数またはスペクトログラムを使用した畳み込みニューラル ネットワーク (CNN) で構成されます。これらの方法の結果、パフォーマンスは約 84% になりました [4]。スペクトログラム時間のスライスを使用した LSTM 方法は精度が 79% となり、アンサンブル学習方法 (AdaBoost) と一体になった時間領域と周波数領域の特徴はテスト セットで精度が 82% になりました [2][3]。最近、スパース表現機械学習方法は約 89% の精度を達成しました [6]。

ウェーブレット散乱フレームワーク

ウェーブレット時間散乱フレームワークで指定する唯一のパラメーターは、時間不変性の持続時間、ウェーブレット フィルター バンクの数、およびオクターブあたりのウェーブレットの数になります。ほとんどのアプリケーションでは、2 つのウェーブレット フィルター バンクからのデータのカスケードで十分です。この例では、2 つのウェーブレット フィルター バンクを使用する既定の散乱フレームワークを使用します。最初のフィルター バンクはオクターブあたり 8 つのウェーブレット、2 つ目のフィルター バンクはオクターブあたり 1 つのウェーブレットです。この例の場合、不変スケールを 0.5 秒になるように設定します。これは、指定したサンプリング レートの 11,000 をわずかに超えるサンプルに対応します。ウェーブレット時間散乱分解フレームワークを作成します。

sf = waveletScattering('SignalLength',2^19,'SamplingFrequency',22050,...
    'InvarianceScale',0.5);

不変スケールの役割を理解するには、最初のフィルター バンクから最も粗いスケール ウェーブレットの実数部と虚数部と一緒にスケーリング フィルターを取得して時間でプロットします。スケーリング フィルターの時間サポートが設計したように基本的に 0.5 秒になることに注意してください。さらに、最も粗いスケール ウェーブレットの時間サポートは、ウェーブレット散乱分解の不変スケールを超えません。

[fb,f,filterparams] = filterbank(sf);
phi = ifftshift(ifft(fb{1}.phift));
psiL1 = ifftshift(ifft(fb{2}.psift(:,end)));
dt = 1/22050;
time = -2^18*dt:dt:2^18*dt-dt;
scalplt = plot(time,phi,'linewidth',1.5);
hold on
grid on
ylimits = [-3e-4 3e-4];
ylim(ylimits);
plot([-0.25 -0.25],ylimits,'k--');
plot([0.25 0.25],ylimits,'k--');
xlim([-0.6 0.6]);
xlabel('Seconds'); ylabel('Amplitude');
wavplt = plot(time,[real(psiL1) imag(psiL1)]);
legend([scalplt wavplt(1) wavplt(2)],{'Scaling Function','Wavelet-Real Part','Wavelet-Imaginary Part'});
title({'Scaling Function';'Coarsest-Scale Wavelet First Filter Bank'})
hold off

オーディオ データストア

オーディオ データストアによりオーディオ データ ファイルの収集の管理が可能になります。機械学習または深層学習の場合、オーディオ データストアは、ファイルとフォルダーからのオーディオ データのフローの管理だけでなく、データとラベルの関連付けも管理し、データをトレーニング、検証、およびテストの異なる設定にランダムに分割する機能を提供します。この例では、オーディオ データストアを使用して、GTZAN 音楽のジャンル コレクションを管理します。コレクションの各サブフォルダーはジャンルを表す名前が付けられるということを思い出してください。サブフォルダーを使用するためにオーディオ データストアを構築するには 'IncludeSubFolders' プロパティを true に設定し、サブフォルダー名を基にしたデータ ラベルを作成するには 'LabelSource' プロパティを 'foldernames' に設定します。この例では、トップレベルのディレクトリは MATLAB tempdir ディレクトリ内にあると想定し、'ジャンル' と呼ばれます。location がマシン上のトップレベルのデータ フォルダーへの正しいパスを示していることを確認します。マシン上のトップレベルのデータ フォルダーは、10 個のジャンルにそれぞれ名前が付いた 10 個のサブフォルダーを含み、それらのジャンルに対応するオーディオ ファイルのみ含まなければなりません。

location = fullfile(tempdir,'genres');
ads = audioDatastore(location,'IncludeSubFolders',true,...
    'LabelSource','foldernames');

次を実行して、データセットの音楽ジャンル数を取得します。

countEachLabel(ads)
ans =

  10×2 table

      Label      Count
    _________    _____

    blues         100 
    classical     100 
    country       100 
    disco         100 
    hiphop        100 
    jazz          100 
    metal         100 
    pop           100 
    reggae        100 
    rock          100 

前述のとおり、ジャンルが 10 個とそれぞれにファイルが 100 個あります。

学習セットとテスト セット

学習セットとテスト セットを作成し、分類器を開発してテストします。学習にはデータの 80% を使用して、テストには残りの 20% をホールドアウトします。オーディオ データストアの関数 shuffle はランダムにデータをシャッフルします。データをランダム化するにはラベルによってデータを分割する前にこれを実行します。この例では、再現性をもたせるために乱数発生器のシードを設定します。オーディオ データストア関数 splitEachLabel を使用して、80-20 分割を実行します。splitEachLabel は、すべてのクラスが均等に表されることを確認します。

rng(100);
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
countEachLabel(adsTest)
ans =

  10×2 table

      Label      Count
    _________    _____

    blues         80  
    classical     80  
    country       80  
    disco         80  
    hiphop        80  
    jazz          80  
    metal         80  
    pop           80  
    reggae        80  
    rock          80  


ans =

  10×2 table

      Label      Count
    _________    _____

    blues         20  
    classical     20  
    country       20  
    disco         20  
    hiphop        20  
    jazz          20  
    metal         20  
    pop           20  
    reggae        20  
    rock          20  

学習データには予想どおり 800 個の記録が、テスト データには 200 個の記録があることがわかります。さらに、学習セットには各ジャンルの 80 個の例が、テスト セットには各ジャンルの 20 個の例があります。

audioDatastore は MATLAB tall 配列で有効です。学習セットとテスト セットの両方に tall 配列を作成します。システムに応じて、並列プール MATLAB が作成するワーカー数は異なる場合があります。

Ttrain = tall(adsTrain);
Ttest = tall(adsTest);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

散乱特徴量を取得するには、補助関数 helperscatfeatures を決定します。この関数で、各オーディオ ファイルの 2^19 サンプルと散乱時間枠の数を 8 で割ったサブサンプルに対する散乱特徴量の自然対数を取得します。helperscatfeatures のソース コードの一覧を付録に示します。学習データとテスト データの両方のウェーブレット散乱特徴量を計算します。

scatteringTrain = cellfun(@(x)helperscatfeatures(x,sf),Ttrain,'UniformOutput',false);
scatteringTest = cellfun(@(x)helperscatfeatures(x,sf),Ttest,'UniformOutput',false);

学習データの散乱特徴量を計算して行列にすべての特徴量をまとめます。この処理には数分かかります。

TrainFeatures = gather(scatteringTrain);
TrainFeatures = cell2mat(TrainFeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 8 min 9 sec
Evaluation completed in 8 min 9 sec

テスト データに対してこの処理を繰り返します。

TestFeatures = gather(scatteringTest);
TestFeatures = cell2mat(TestFeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 2 min 6 sec
Evaluation completed in 2 min 6 sec

TrainFeaturesTestFeatures の各行は各オーディオ信号の散乱変換で 341 個のパス全体にまたがる 1 つの散乱時間枠です。各音楽サンプルの場合、32 個の時間枠があります。したがって、学習データの特徴量行列は 25600 x 341 です。行数は、学習例の数 (800) と例あたりの散乱時間枠数 (32) の積と等しくなります。同様に、テスト データの散乱特徴量行列は 6400 x 341 です。200 個のテスト例と例あたり 32 個の枠があります。学習データ用 32 個のウェーブレット散乱特徴量行列の枠のそれぞれにジャンル ラベルを作成します。

numTimeWindows = 32;
trainLabels = adsTrain.Labels;
numTrainSignals = numel(trainLabels);
trainLabels = repmat(trainLabels,1,numTimeWindows);
trainLabels = reshape(trainLabels',numTrainSignals*numTimeWindows,1);

テスト データに対して手順を繰り返します。

testLabels = adsTest.Labels;
numTestSignals = numel(testLabels);
testLabels = repmat(testLabels,1,numTimeWindows);
testLabels = reshape(testLabels',numTestSignals*numTimeWindows,1);

この例では、3 次多項式カーネルをもつマルチクラスのサポート ベクター マシン (SVM) 分類器を使用します。SVM を学習データに適合させます。

template = templateSVM(...
    'KernelFunction', 'polynomial', ...
    'PolynomialOrder', 3, ...
    'KernelScale', 'auto', ...
    'BoxConstraint', 1, ...
    'Standardize', true);
Classes = {'blues','classical','country','disco','hiphop','jazz',...
    'metal','pop','reggae','rock'};
classificationSVM = fitcecoc(...
    TrainFeatures, ...
    trainLabels, ...
    'Learners', template, ...
    'Coding', 'onevsone','ClassNames',categorical(Classes));

テスト セットの予測

学習データの散乱変換に適合させる SVM モデルを使用して、テスト データの音楽ジャンルを予測します。散乱変換の各信号に 32 個の時間枠があることを思い出してください。簡単な多数決を使用して、ジャンルを予測します。補助関数 helperMajorityVote は、32 個の散乱時間枠すべてにわたるジャンル ラベルのモードを取得します。一意のモードがない場合、helperMajorityVote は、'NoUniqueMode' によって指定された分類誤差を返します。これは、混同行列の余分の列になります。helperMajorityVote のソース コードの一覧を付録に示します。

predLabels = predict(classificationSVM,TestFeatures);
[TestVotes,TestCounts] = helperMajorityVote(predLabels,adsTest.Labels,categorical(Classes));
testAccuracy = sum(eq(TestVotes,adsTest.Labels))/numTestSignals*100;

テスト精度、testAccuracy は 88% です。この精度は、最先端の GTZAN データセットと同様です。

ジャンル x ジャンルの精度率を検証するために混同行列を表示します。各クラスに 20 個の例があることを思い出してください。

confusionchart(TestVotes,adsTest.Labels);

混同行列プロットの対角は、個々のジャンルの分類精度が一般的に非常に良好であることを示します。これらのジャンル精度を抽出して、個別にプロットします。

cm = confusionmat(TestVotes,adsTest.Labels);
cm(:,end) = [];
genreAccuracy = diag(cm)./20*100;
figure;
bar(genreAccuracy)
set(gca,'XTickLabels',Classes);
xtickangle(gca,30);
title('Percentage Correct by Genre - Test Set');

まとめ

この例は、音楽ジャンル分類でのウェーブレット時間散乱とオーディオ データストアの使用を示しています。この例では、ウェーブレット時間散乱は、GTZAN データセットの最先端の性能と同程度の分類精度を達成しました。多くの時間領域と周波数領域の特徴量の抽出を必要とする他の方法と異なり、ウェーブレット散乱には単一のパラメーターの仕様である、時不変スケールのみが必要でした。オーディオ データストアによって、ディスクから MATLAB への大規模なデータセットの変換を効率的に管理することができ、データをランダム化して分類ワークフローからランダム化されたデータのジャンル メンバーシップを正確に保持することが許可されました。

参考文献

  1. Anden, J. and Mallat, S. 2014. Deep scattering spectrum.IEEE Transactions on Signal Processing, Vol. 62, 16, pp. 4114-4128.

  2. Bergstra, J., Casagrande, N., Erhan, D., Eck, D., and Kegl, B. Aggregate features and AdaBoost for music classification.Machine Learning, Vol. 65, Issue 2-3, pp. 473-484.

  3. Irvin, J., Chartock, E., and Hollander, N. 2016. Recurrent neural networks with attention for genre classification. https://www.semanticscholar.org/paper/Recurrent-Neural-Networks-with-Attention-for-Genre-Irvin-Chartock/bff3eaf5d8ebb6e613ae0146158b2b5346ee7323

  4. Li, T., Chan, A.B., and Chun, A. 2010. Automatic musical pattern feature extraction using convolutional neural network.International Conference Data Mining and Applications.

  5. Mallat.S. 2012. Group invariant scattering.Communications on Pure and Applied Mathematics, Vol. 65, 10, pp. 1331-1398.

  6. Panagakis, Y., Kotropoulos, C.L., and Arce, G.R.2014.Music genre classification via joint sparse low-rank representation of audio features.IEEE Transactions on Audio, Speech, and Language Processing, 22, 12, pp. 1905-1917.

  7. Tzanetakis, G. and Cook, P. 2002. Music genre classification of audio signals.IEEE Transactions on Speech and Audio Processing, Vol. 10, No. 5, pp. 293-302.

  8. GTZAN ジャンル コレクションhttp://marsyas.info/downloads/datasets.html

付録 -- サポート関数

helperMajorityVote -- この関数は、多数の特徴ベクトル間で予測されたクラス ラベルのモードを返します。ウェーブレット時間散乱で、各時間枠のクラス ラベルを取得します。固有のモードが見つからない場合、分類誤差を示す 'NoUniqueMode' のラベルが返されます。

function [ClassVotes,ClassCounts] = helperMajorityVote(predLabels,origLabels,classes)
% This function is in support of wavelet scattering examples only. It may
% change or be removed in a future release.

% Make categorical arrays if the labels are not already categorical
predLabels = categorical(predLabels);
origLabels = categorical(origLabels);
% Expects both predLabels and origLabels to be categorical vectors
Npred = numel(predLabels);
Norig = numel(origLabels);
Nwin = Npred/Norig;
predLabels = reshape(predLabels,Nwin,Norig);
ClassCounts = countcats(predLabels);
[mxcount,idx] = max(ClassCounts);
ClassVotes = classes(idx);
% Check for any ties in the maximum values and ensure they are marked as
% error if the mode occurs more than once
modecnt = modecount(ClassCounts,mxcount);
ClassVotes(modecnt>1) = categorical({'NoUniqueMode'});
ClassVotes = ClassVotes(:);

%-------------------------------------------------------------------------
    function modecnt = modecount(ClassCounts,mxcount)
        modecnt = Inf(size(ClassCounts,2),1);
        for nc = 1:size(ClassCounts,2)
            modecnt(nc) = histc(ClassCounts(:,nc),mxcount(nc));
        end
    end
end

helperscatfeatures - この関数は、指定した入力信号のウェーブレット時間散乱の特徴量行列を返します。この場合、ウェーブレット散乱係数の自然対数を使用します。散乱特徴量行列は信号の 2^19 サンプルで計算されます。散乱特徴量は係数 8 でサブサンプリングされます。

function features = helperscatfeatures(x,sf)
% This function is in support of wavelet scattering examples only. It may
% change or be removed in a future release.

features = featureMatrix(sf,x(1:2^19),'Transform','log');
features = features(:,1:8:end)';
end