メインコンテンツ

事前学習済みオーディオ ネットワークを使用した転移学習

この例では、転移学習を使用して、事前学習済みの畳み込みニューラル ネットワークである YAMNet を再学習させ、新しいオーディオ信号セットを分類する方法を示します。オーディオ深層学習を最初から始めるには、Classify Sound Using Deep Learning (Audio Toolbox)を参照してください。

転移学習は、深層学習アプリケーションでよく使用されています。事前学習済みのネットワークを取得して、新しいタスクの学習の開始点として使用できます。通常は、転移学習によってネットワークを微調整する方が、ランダムに初期化された重みでゼロからネットワークに学習させるよりもはるかに簡単で時間がかかりません。より少ない数の学習信号を使用して、学習済みの特徴を新しいタスクに高速に転移できます。

pretainedworkflow.png

Audio Toolbox™ はさらに、YAMNet に必要な前処理と結果を解釈するための便利な後処理を実装するclassifySound (Audio Toolbox)関数も提供します。Audio Toolbox は、事前学習済みの VGGish ネットワーク (vggish (Audio Toolbox)) と、VGGish ネットワークの前処理と後処理を実装するvggishEmbeddings (Audio Toolbox)関数も提供します。

データの作成

100 個のホワイト ノイズ信号、100 個のブラウン ノイズ信号、100 個のピンク ノイズ信号を生成します。各信号は、16 kHz のサンプル レートを想定した場合、0.98 秒の持続時間を表します。

fs = 16e3;
duration = 0.98;
N = duration*fs;
numSignals = 100;

wNoise = 2*rand([N,numSignals]) - 1;
wLabels = repelem(categorical("white"),numSignals,1);

bNoise = filter(1,[1,-0.999],wNoise);
bNoise = bNoise./max(abs(bNoise),[],"all");
bLabels = repelem(categorical("brown"),numSignals,1);

pNoise = pinknoise([N,numSignals]);
pLabels = repelem(categorical("pink"),numSignals,1);

データを学習セットとテスト セットに分割します。通常、学習セットはデータの大部分で構成されます。ただし、転移学習の効果を実証するため、学習には少数のサンプルのみを使用し、大部分を検証に使用します。

K = 5;

trainAudio = [wNoise(:,1:K),bNoise(:,1:K),pNoise(:,1:K)];
trainLabels = [wLabels(1:K);bLabels(1:K);pLabels(1:K)];

validationAudio = [wNoise(:,K+1:end),bNoise(:,K+1:end),pNoise(:,K+1:end)];
validationLabels = [wLabels(K+1:end);bLabels(K+1:end);pLabels(K+1:end)];

fprintf("Number of samples per noise color in train set = %d\n" + ...
        "Number of samples per noise color in validation set = %d\n",K,numSignals-K);
Number of samples per noise color in train set = 5
Number of samples per noise color in validation set = 95

特徴の抽出

yamnetPreprocess (Audio Toolbox)を使用し、YAMNet モデルの学習時に使用したのと同じパラメーターを使って、学習セットと検証セットの両方から log-mel スペクトログラムを抽出します。

trainFeatures = yamnetPreprocess(trainAudio,fs);
validationFeatures = yamnetPreprocess(validationAudio,fs);

転移学習

事前学習済みのネットワークを読み込むには、yamnet (Audio Toolbox)を呼び出します。YAMNet 用の Audio Toolbox モデルがインストールされていない場合、この関数はネットワークの重みファイルの場所へのリンクを提供します。モデルをダウンロードするには、リンクをクリックします。MATLAB パス上の場所にファイルを解凍します。YAMNet モデルは、ホワイト ノイズやピンク ノイズ (ブラウン ノイズは除く) を含む 521 種類のサウンド カテゴリのいずれかにオーディオを分類できます。

net = yamnet;
net.Layers(end).Classes
ans = 521×1 categorical
     Speech 
     Child speech, kid speaking 
     Conversation 
     Narration, monologue 
     Babbling 
     Speech synthesizer 
     Shout 
     Bellow 
     Whoop 
     Yell 
     Children shouting 
     Screaming 
     Whispering 
     Laughter 
     Baby laughter 
     Giggle 
     Snicker 
     Belly laugh 
     Chuckle, chortle 
     Crying, sobbing 
     Baby cry, infant cry 
     Whimper 
     Wail, moan 
     Sigh 
     Singing 
     Choir 
     Yodeling 
     Chant 
     Mantra 
     Child singing 
      ⋮

まずネットワークをlayerGraphに変換して、転移学習用のモデルを準備します。全結合層を未学習の全結合層に置き換えるには、replaceLayerを使用します。分類層を、入力を "white"、"pink"、または "brown" に分類する分類層に置き換えます。MATLAB® でサポートされている深層学習層については、深層学習層の一覧を参照してください。

uniqueLabels = unique(trainLabels);
numLabels = numel(uniqueLabels);

lgraph = layerGraph(net.Layers);

lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,Name="dense"));
lgraph = replaceLayer(lgraph,"Sound",classificationLayer(Name="Sounds",Classes=uniqueLabels));

学習オプションを定義するには、trainingOptionsを使用します。

options = trainingOptions("adam",ValidationData={single(validationFeatures),validationLabels});

ネットワークに学習させるには、trainNetworkを使用します。このネットワークは、ノイズ タイプごとにわずか 5 つの信号を使用して 100% の検証精度を実現しています。

trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:01 |       20.00% |       88.42% |       1.1922 |       0.6651 |          0.0010 |
|      30 |          30 |       00:00:14 |      100.00% |      100.00% |   5.0068e-06 |       0.0003 |          0.0010 |
|======================================================================================================================|
Training finished: Max epochs completed.