メインコンテンツ

事前学習済みオーディオ ネットワークを使用した転移学習

この例では、転移学習を使用して、事前学習済みの畳み込みニューラル ネットワークである YAMNet を再学習させ、新しいオーディオ信号セットを分類する方法を示します。オーディオ深層学習を最初から始めるには、Classify Sound Using Deep Learning (Audio Toolbox)を参照してください。

転移学習は、深層学習アプリケーションでよく使用されています。事前学習済みのネットワークを取得して、新しいタスクの学習の開始点として使用できます。通常は、転移学習によってネットワークを微調整する方が、ランダムに初期化された重みでゼロからネットワークに学習させるよりもはるかに簡単で時間がかかりません。より少ない数の学習信号を使用して、学習済みの特徴を新しいタスクに高速に転移できます。

pretainedworkflow.png

Audio Toolbox™ はさらに、YAMNet に必要な前処理と結果を解釈するための便利な後処理を実装するclassifySound (Audio Toolbox)関数も提供します。Audio Toolbox は、事前学習済みの VGGish ネットワーク (vggish (Audio Toolbox)) と、VGGish ネットワークの前処理と後処理を実装するvggishEmbeddings (Audio Toolbox)関数も提供します。

データの作成

100 個のホワイト ノイズ信号、100 個のブラウン ノイズ信号、100 個のピンク ノイズ信号を生成します。各信号は、16 kHz のサンプル レートを想定した場合、0.98 秒の持続時間を表します。

fs = 16e3;
duration = 0.98;
N = duration*fs;
numSignals = 100;

wNoise = 2*rand([N,numSignals]) - 1;
wLabels = repelem(categorical("white"),numSignals,1);

bNoise = filter(1,[1,-0.999],wNoise);
bNoise = bNoise./max(abs(bNoise),[],"all");
bLabels = repelem(categorical("brown"),numSignals,1);

pNoise = pinknoise([N,numSignals]);
pLabels = repelem(categorical("pink"),numSignals,1);

データを学習セットとテスト セットに分割します。通常、学習セットはデータの大部分で構成されます。ただし、転移学習の効果を実証するため、学習には少数のサンプルのみを使用し、大部分を検証に使用します。

K = 5;

trainAudio = [wNoise(:,1:K),bNoise(:,1:K),pNoise(:,1:K)];
trainLabels = [wLabels(1:K);bLabels(1:K);pLabels(1:K)];

validationAudio = [wNoise(:,K+1:end),bNoise(:,K+1:end),pNoise(:,K+1:end)];
validationLabels = [wLabels(K+1:end);bLabels(K+1:end);pLabels(K+1:end)];

fprintf("Number of samples per noise color in train set = %d\n" + ...
        "Number of samples per noise color in validation set = %d\n",K,numSignals-K);
Number of samples per noise color in train set = 5
Number of samples per noise color in validation set = 95

特徴の抽出

yamnetPreprocess (Audio Toolbox)を使用し、YAMNet モデルの学習時に使用したのと同じパラメーターを使って、学習セットと検証セットの両方から log-mel スペクトログラムを抽出します。

trainFeatures = yamnetPreprocess(trainAudio,fs);
validationFeatures = yamnetPreprocess(validationAudio,fs);

転移学習

事前学習済みのネットワークを読み込むには、yamnet (Audio Toolbox)を呼び出します。YAMNet 用の Audio Toolbox モデルがインストールされていない場合、この関数はネットワークの重みファイルの場所へのリンクを提供します。モデルをダウンロードするには、リンクをクリックします。MATLAB パス上の場所にファイルを解凍します。YAMNet モデルは、ホワイト ノイズやピンク ノイズ (ブラウン ノイズは除く) を含む 521 種類のサウンド カテゴリのいずれかにオーディオを分類できます。

net = yamnet;
net.Layers(end).Classes
ans = 521×1 categorical array
    "Speech"
    "Child speech, kid speaking"
    "Conversation"
    "Narration, monologue"
    "Babbling"
    "Speech synthesizer"
    "Shout"
    "Bellow"
    "Whoop"
    "Yell"
      ⋮

まずネットワークをlayerGraphに変換して、転移学習用のモデルを準備します。全結合層を未学習の全結合層に置き換えるには、replaceLayerを使用します。分類層を、入力を "white"、"pink"、または "brown" に分類する分類層に置き換えます。MATLAB® でサポートされている深層学習層については、深層学習層の一覧を参照してください。

uniqueLabels = unique(trainLabels);
numLabels = numel(uniqueLabels);

lgraph = layerGraph(net.Layers);

lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,Name="dense"));
lgraph = replaceLayer(lgraph,"Sound",classificationLayer(Name="Sounds",Classes=uniqueLabels));

学習オプションを定義するには、trainingOptionsを使用します。

options = trainingOptions("adam",ValidationData={single(validationFeatures),validationLabels});

ネットワークに学習させるには、trainNetworkを使用します。このネットワークは、ノイズ タイプごとにわずか 5 つの信号を使用して 100% の検証精度を実現しています。

trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:01 |       20.00% |       88.42% |       1.1922 |       0.6651 |          0.0010 |
|      30 |          30 |       00:00:14 |      100.00% |      100.00% |   5.0068e-06 |       0.0003 |          0.0010 |
|======================================================================================================================|
Training finished: Max epochs completed.