Main Content

音声合成の敵対的生成ネットワーク (GAN) の学習

この例では、敵対的生成ネットワーク (GAN) に学習させ、そのネットワークを使用して音声を生成する方法を説明します。

はじめに

敵対的生成ネットワークでは、ジェネレーターとディスクリミネーターが互いに競い合うことで、生成品質を向上させます。

GAN は、オーディオ処理および音声処理の分野で高い関心を集めました。応用例としては、テキスト音声合成、音声変換、音声強調などがあります。

この例では、GAN の教師なし学習を実行してオーディオ波形を合成します。この例の GAN はドラムビート音を生成します。話し声など、他の種類の音も、同じ方法で生成できます。

事前学習済みの GAN を使用したオーディオ合成

GAN にゼロから学習させる前に、事前学習済みの GAN ジェネレーターを使用してドラムビートを合成します。

事前学習済みのジェネレーターをダウンロードします。

matFileName = "drumGeneratorWeights.mat";
loc = matlab.internal.examples.downloadSupportFile("audio","GanAudioSynthesis/" + matFileName);
copyfile(loc,pwd)

関数 synthesizeDrumBeat が、事前学習済みのネットワークを呼び出して、16 kHz でサンプリングされたドラムビートを合成します。関数 synthesizeDrumBeat は、この例の最後に示されています。

ドラムビートを合成して再生します。

drum = synthesizeDrumBeat;

fs = 16e3;
sound(drum,fs)

合成されたドラムビートをプロットします。

t = (0:length(drum)-1)/fs;
plot(t,drum)
grid on
xlabel("Time (s)")
title("Synthesized Drum Beat")

ドラムビート シンセサイザーを使用して他のオーディオ エフェクトを適用することで、より複雑なアプリケーションを作成できます。たとえば、合成されたドラムビートに残響を適用できます。

reverberator (Audio Toolbox)オブジェクトを作成し、そのパラメーター チューナー UI を開きます。シミュレーション実行時、この UI を使用して reverberator パラメーターを調整できます。

reverb = reverberator(SampleRate=fs);
parameterTuner(reverb);

ドラム ビートを可視化する time scope オブジェクトを作成します。

ts = timescope(SampleRate=fs, ...
    TimeSpanSource="Property", ...
    TimeSpanOverrunAction="Scroll", ...
    TimeSpan=10, ...
    BufferLength=10*256*64, ...
    ShowGrid=true, ...
    YLimits=[-1 1]);

ループ内で、ドラムビートを合成して残響を適用します。パラメーター チューナー UI を使用して残響を調整します。より長い時間シミュレーションを実行する場合は、loopCount パラメーターの値を大きくします。

loopCount = 20;
for ii = 1:loopCount
    drum = synthesizeDrumBeat;
    drum = reverb(drum);
    ts(drum(:,1));
    soundsc(drum,fs)
    pause(0.5)
end

GAN の学習

事前学習済みのドラムビート ジェネレーターの動作を確認できたので、次に、学習プロセスを詳しく見ていきます。

GAN は深層学習ネットワークの一種で、学習データと類似した特性をもつデータを生成します。

GAN は同時に学習させることが可能な 2 つのネットワーク、すなわち、"ジェネレーター" と "ディスクリミネーター" で構成されています。

  • ジェネレーター。ベクトルまたはランダムな値を入力として与えられ、このネットワークは学習データと同じ構造のデータを生成します。ディスクリミネーターを騙すのがジェネレーターの仕事です。

  • ディスクリミネーター。学習データと生成されたデータの両方の観測値を含むデータのバッチを与えられ、このネットワークはその観測値が実データか生成データかの分類を試みます。

ジェネレーターのパフォーマンスを最大化するには、生成されたデータが与えられたときのディスクリミネーターの損失を最大化します。つまり、ジェネレーターの目的はディスクリミネーターが実データと分類するようなデータを生成することです。ディスクリミネーターのパフォーマンスを最大化するには、実データと生成データ両方のバッチが与えられたときのディスクリミネーターの損失を最小化します。これらの方法によって、いかにも本物らしいデータを生成するジェネレーターと、学習データの特性である強い特徴表現を学習したディスクリミネーターを得ることが理想的な結果です。

この例では、ドラムビートを表す偽の時間-周波数短時間フーリエ変換 (STFT) を作成するようにジェネレーターに学習させます。本物の STFT を識別するようにディスクリミネーターに学習させます。短時間録音した実際のドラムビートの STFT を計算し、本物の STFT を作成します。

学習データの読み込み

Drum Sound Effects データセット [1] で GAN に学習させます。データセットをダウンロードし、解凍します。

url = "http://deepyeti.ucsd.edu/cdonahue/wavegan/data/drums.tar.gz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"drums_dataset.tgz");

drumsFolder = fullfile(downloadFolder,"drums");
if ~datasetExists(drumsFolder)
    disp("Downloading Drum Sound Effects Dataset (218 MB) ...")
    websave(filename,url);
    untar(filename,downloadFolder)
end
Downloading Drum Sound Effects Dataset (218 MB) ...

ドラム データセットを指すaudioDatastore (Audio Toolbox)オブジェクトを作成します。

ads = audioDatastore(drumsFolder,IncludeSubfolders=true);

ジェネレーター ネットワークの定義

乱数値の 1 x 1 x 100 の配列から STFT を生成するネットワークを定義します。全結合層、およびそれに続く形状変更層と、ReLU 層をもつ一連の転置畳み込み層を使用して、1 x 1 x 100 の配列を 128 x 128 x 1 の配列にアップスケールするネットワークを作成します。

この図は、ジェネレーターを通過するときの信号の次元を示しています。ジェネレーターのアーキテクチャは [1] の表 4 で定義されています。

ジェネレーター ネットワークは、この例の最後に示されている modelGenerator で定義されています。

ディスクリミネーター ネットワークの定義

128 x 128 の STFT の実イメージと生成イメージを分類するネットワークを定義します。

128 x 128 のイメージを受け取り、leaky ReLU 層をもつ一連の畳み込み層とそれに続く全結合層を使用してスカラーの予測スコアを出力するネットワークを作成します。

この図は、ディスクリミネーターを通過するときの信号の次元を示しています。このディスクリミネーターのアーキテクチャは [1] の表 5 で定義されています。

このディスクリミネーター ネットワークは、この例の最後に示されている modelDiscriminator で定義されています。

本物のドラムビートの学習データの生成

データストアにあるドラムビート信号から STFT データを生成します。

STFT パラメーターを定義します。

fftLength = 256;
win = hann(fftLength,"periodic");
overlapLength = 128;

処理を高速化するために、parforを使用して複数のワーカーに特徴抽出を分散します。

最初に、データセットの区画数を決定します。Parallel Computing Toolbox™ がない場合は、単一の区画を使用します。

if ~isempty(ver("parallel"))
    pool = gcp;
    numPar = numpartitions(ads,pool);
else
    numPar = 1;
end

区画ごとにデータストアから読み取り、STFT を計算します。

parfor ii = 1:numPar

    subds = partition(ads,numPar,ii);
    STrain = zeros(fftLength/2+1,128,1,numel(subds.Files));
    
    for idx = 1:numel(subds.Files)
        
        x = read(subds);
        
        if length(x) > fftLength*64 
            % Lengthen the signal if it is too short
            x = x(1:fftLength*64);
        end
        
        % Convert from double-precision to single-precision
        x = single(x);
        
        % Scale the signal
        x = x ./ max(abs(x));
        
        % Zero-pad to ensure stft returns 128 windows.
        x = [x;zeros(overlapLength,1,"like",x)];
        
        S0 = stft(x,Window=win,OverlapLength=overlapLength,Centered=false);
        
        % Convert from two-sided to one-sided.
        S = S0(1:129,:);
        S = abs(S);
        STrain(:,:,:,idx) = S;
    end
    STrainC{ii} = STrain;
end

4 番目の次元に STFT をもつ 4 次元配列に出力を変換します。

STrain = cat(4,STrainC{:});

人間の知覚に合わせるため、データを対数スケールに変換します。

STrain = log(STrain + 1e-6);

平均値が 0、標準偏差が 1 となるように学習データを正規化します。

各周波数ビンにおける STFT の平均値と標準偏差を計算します。

SMean = mean(STrain,[2 3 4]);
SStd = std(STrain,1,[2 3 4]);

各周波数ビンを正規化します。

STrain = (STrain-SMean)./SStd;

計算された STFT は非有界の値をもちます。[1] の方法に従い、スペクトルを 3 の標準偏差にクリッピングしてから [-1 1] に再スケーリングし、データを有界にします。

STrain = STrain/3;
Y = reshape(STrain,numel(STrain),1);
Y(Y<-1) = -1;
Y(Y>1) = 1;
STrain = reshape(Y,size(STrain));

STFT ビンの数が 2 のべき乗となるように、最後の周波数ビンを破棄します (畳み込み層に適しています)。

STrain = STrain(1:end-1,:,:,:);

ディスクリミネーターに入力できるように、次元を並べ替えます。

STrain = permute(STrain,[2 1 3 4]);

学習オプションの指定

ミニバッチ サイズを 64 として 1000 エポック学習させます。

maxEpochs = 1000;
miniBatchSize = 64;

データを使い切るのに必要な反復回数を計算します。

numIterationsPerEpoch = floor(size(STrain,4)/miniBatchSize);

Adam 最適化のオプションを指定します。ジェネレーターとディスクリミネーターの学習率を 0.0002 に設定します。どちらのネットワークも、勾配の減衰係数に 0.5 を、2 乗勾配の減衰係数に 0.999 を使用します。

learnRateGenerator = 0.0002;
learnRateDiscriminator = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには Parallel Computing Toolbox™ が必要です。

executionEnvironment = "auto";

ジェネレーターとディスクリミネーターの重みを初期化します。関数 initializeGeneratorWeights および initializeDiscriminatorWeights は、Glorot の一様分布による初期化を使用して得られたランダムな重みを返します。この関数は、この例の最後に示されています。

generatorParameters = initializeGeneratorWeights;
discriminatorParameters = initializeDiscriminatorWeights;

モデルの学習

カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し、各反復でネットワーク パラメーターを更新します。

各エポックで、学習データをシャッフルしてデータのミニバッチについてループします。

各ミニバッチで次を行います。

  • ジェネレーター ネットワーク用の乱数値の配列を含む dlarray オブジェクトを生成。

  • GPU で学習する場合、データを gpuArray (Parallel Computing Toolbox) オブジェクトに変換。

  • dlfeval、および補助関数 modelDiscriminatorGradients および modelGeneratorGradients を使用してモデル勾配を評価。

  • 関数adamupdateを使用してネットワーク パラメーターを更新。

Adam のパラメーターを初期化します。

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

saveCheckpointstrue に設定すると、更新された重みと状態が 10 エポックごとに MAT ファイルに保存されます。学習が中断された場合、この MAT ファイルを使用して学習を再開できます。この例では、saveCheckpointsfalse に設定します。

saveCheckpoints = false;

ジェネレーターの入力の長さを指定します。

numLatentInputs = 100;

GAN に学習させます。実行には数時間かかることがあります。

iteration = 0;

for epoch = 1:maxEpochs

    % Shuffle the data.
    idx = randperm(size(STrain,4));
    STrain = STrain(:,:,:,idx);

    % Loop over mini-batches.
    for index = 1:numIterationsPerEpoch
        
        iteration = iteration + 1;

        % Read mini-batch of data.
        dlX = STrain(:,:,:,(index-1)*miniBatchSize+1:index*miniBatchSize);
        dlX = dlarray(dlX,"SSCB");
        
        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,"single") - 0.5 ) ;
        dlZ = dlarray(Z);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the discriminator gradients using dlfeval and the
        % modelDiscriminatorGradients helper function.
        gradientsDiscriminator = ...
            dlfeval(@modelDiscriminatorGradients,discriminatorParameters,generatorParameters,dlX,dlZ);
        
        % Update the discriminator network parameters.
        [discriminatorParameters,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(discriminatorParameters,gradientsDiscriminator, ...
            trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ...
            learnRateDiscriminator,gradientDecayFactor,squaredGradientDecayFactor);

        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,"single") - 0.5 ) ;
        dlZ = dlarray(Z);
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the generator gradients using dlfeval and the
        % |modelGeneratorGradients| helper function.
        gradientsGenerator  = ...
            dlfeval(@modelGeneratorGradients,discriminatorParameters,generatorParameters,dlZ);
        
        % Update the generator network parameters.
        [generatorParameters,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(generatorParameters,gradientsGenerator, ...
            trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
            learnRateGenerator,gradientDecayFactor,squaredGradientDecayFactor);
    end

    % Every 10 epochs, save a training snapshot to a MAT file.
    if mod(epoch,10)==0
        disp("Epoch " + epoch + " out of " + maxEpochs + " complete.");
        if saveCheckpoints
            % Save checkpoint in case training is interrupted.
            save("audiogancheckpoint.mat", ...
                "generatorParameters","discriminatorParameters", ...
                "trailingAvgDiscriminator","trailingAvgSqDiscriminator", ...
                "trailingAvgGenerator","trailingAvgSqGenerator","iteration");
        end
    end
end
Epoch 10 out of 1000 complete.
Epoch 20 out of 1000 complete.
Epoch 30 out of 1000 complete.
Epoch 40 out of 1000 complete.
Epoch 50 out of 1000 complete.
Epoch 60 out of 1000 complete.
Epoch 70 out of 1000 complete.
Epoch 80 out of 1000 complete.
Epoch 90 out of 1000 complete.
Epoch 100 out of 1000 complete.
Epoch 110 out of 1000 complete.
Epoch 120 out of 1000 complete.
Epoch 130 out of 1000 complete.
Epoch 140 out of 1000 complete.
Epoch 150 out of 1000 complete.
Epoch 160 out of 1000 complete.
Epoch 170 out of 1000 complete.
Epoch 180 out of 1000 complete.
Epoch 190 out of 1000 complete.
Epoch 200 out of 1000 complete.
Epoch 210 out of 1000 complete.
Epoch 220 out of 1000 complete.
Epoch 230 out of 1000 complete.
Epoch 240 out of 1000 complete.
Epoch 250 out of 1000 complete.
Epoch 260 out of 1000 complete.
Epoch 270 out of 1000 complete.
Epoch 280 out of 1000 complete.
Epoch 290 out of 1000 complete.
Epoch 300 out of 1000 complete.
Epoch 310 out of 1000 complete.
Epoch 320 out of 1000 complete.
Epoch 330 out of 1000 complete.
Epoch 340 out of 1000 complete.
Epoch 350 out of 1000 complete.
Epoch 360 out of 1000 complete.
Epoch 370 out of 1000 complete.
Epoch 380 out of 1000 complete.
Epoch 390 out of 1000 complete.
Epoch 400 out of 1000 complete.
Epoch 410 out of 1000 complete.
Epoch 420 out of 1000 complete.
Epoch 430 out of 1000 complete.
Epoch 440 out of 1000 complete.
Epoch 450 out of 1000 complete.
Epoch 460 out of 1000 complete.
Epoch 470 out of 1000 complete.
Epoch 480 out of 1000 complete.
Epoch 490 out of 1000 complete.
Epoch 500 out of 1000 complete.
Epoch 510 out of 1000 complete.
Epoch 520 out of 1000 complete.
Epoch 530 out of 1000 complete.
Epoch 540 out of 1000 complete.
Epoch 550 out of 1000 complete.
Epoch 560 out of 1000 complete.
Epoch 570 out of 1000 complete.
Epoch 580 out of 1000 complete.
Epoch 590 out of 1000 complete.
Epoch 600 out of 1000 complete.
Epoch 610 out of 1000 complete.
Epoch 620 out of 1000 complete.
Epoch 630 out of 1000 complete.
Epoch 640 out of 1000 complete.
Epoch 650 out of 1000 complete.
Epoch 660 out of 1000 complete.
Epoch 670 out of 1000 complete.
Epoch 680 out of 1000 complete.
Epoch 690 out of 1000 complete.
Epoch 700 out of 1000 complete.
Epoch 710 out of 1000 complete.
Epoch 720 out of 1000 complete.
Epoch 730 out of 1000 complete.
Epoch 740 out of 1000 complete.
Epoch 750 out of 1000 complete.
Epoch 760 out of 1000 complete.
Epoch 770 out of 1000 complete.
Epoch 780 out of 1000 complete.
Epoch 790 out of 1000 complete.
Epoch 800 out of 1000 complete.
Epoch 810 out of 1000 complete.
Epoch 820 out of 1000 complete.
Epoch 830 out of 1000 complete.
Epoch 840 out of 1000 complete.
Epoch 850 out of 1000 complete.
Epoch 860 out of 1000 complete.
Epoch 870 out of 1000 complete.
Epoch 880 out of 1000 complete.
Epoch 890 out of 1000 complete.
Epoch 900 out of 1000 complete.
Epoch 910 out of 1000 complete.
Epoch 920 out of 1000 complete.
Epoch 930 out of 1000 complete.
Epoch 940 out of 1000 complete.
Epoch 950 out of 1000 complete.
Epoch 960 out of 1000 complete.
Epoch 970 out of 1000 complete.
Epoch 980 out of 1000 complete.
Epoch 990 out of 1000 complete.
Epoch 1000 out of 1000 complete.

音の合成

ネットワークの学習が完了したので、次に、合成プロセスをさらに詳しく見ていきます。

学習済みのドラムビート ジェネレーターは、乱数値の入力配列から短時間フーリエ変換 (STFT) 行列を合成します。逆 STFT (ISTFT) 演算は、時間-周波数の STFT を時間領域の合成オーディオ信号に変換します。

事前学習済みのジェネレーターの重みを読み込みます。これらの重みは、前の節で示した学習を 1000 エポック実行して得られたものです。

load(matFileName,"generatorParameters","SMean","SStd");

このジェネレーターは、乱数値の 1 x 1 x 100 のベクトルを入力として受け取ります。サンプルの入力ベクトルを生成します。

numLatentInputs = 100;
dlZ = dlarray(2*(rand(1,1,numLatentInputs,1,"single") - 0.5));

ランダム ベクトルをジェネレーターに渡し、STFT イメージを作成します。generatorParameters は、事前学習済みのジェネレーターの重みが格納された構造体です。

dlXGenerated = modelGenerator(dlZ,generatorParameters);

STFT の dlarray を単精度行列に変換します。

S = dlXGenerated.extractdata;

STFT を転置して、その次元を関数 istft に合わせます。

S = S.';

この STFT は、128 行 128 列の行列です。ここで、最初の次元は、0 から 8 kHz まで線形に配置された 128 個の周波数ビンを表します。このジェネレーターは、256 の FFT 長から片側 STFT を生成するように学習されています。この STFT では最後のビンが省略されています。STFT に 0 の行を挿入し、このビンを再導入します。

S = [S;zeros(1,128)];

学習用の STFT を生成するのに使用した正規化とスケーリングの手順を元に戻します。

S = S * 3;
S = (S.*SStd) + SMean;

STFT を、対数領域から線形領域に変換します。

S = exp(S);

STFT を、片側から両側に変換します。

S = [S;S(end-1:-1:2,:)];

ゼロでパディングし、ウィンドウのエッジの影響を除去します。

S = [zeros(256,100),S,zeros(256,100)];

STFT 行列には位相情報が含まれていません。高速版の Griffin-Lim アルゴリズムを 20 回繰り返して信号の位相を推定し、オーディオ サンプルを生成します。

myAudio = stftmag2sig(S,256, ...
    FrequencyRange="twosided", ...
    Window=hann(256,"periodic"), ...
    OverlapLength=128, ...
    MaxIterations=20, ...
    Method="fgla");
myAudio = myAudio./max(abs(myAudio),[],"all");
myAudio = myAudio(128*100:end-128*100);

合成されたドラムビートを再生します。

sound(myAudio,fs)

合成されたドラムビートをプロットします。

t = (0:length(myAudio)-1)/fs;
plot(t,myAudio)
grid on
xlabel("Time (s)")
title("Synthesized GAN Sound")

合成されたドラムビートの STFT をプロットします。

figure
stft(myAudio,fs,Window=hann(256,"periodic"),OverlapLength=128);

モデル ジェネレーター関数

関数 modelGenerator は、1 x 1 x 100 の配列 (dlX) を 128 x 128 x 1 の配列 (dlY) にアップスケールします。parameters は、ジェネレーター層の重みが格納された構造体です。ジェネレーターのアーキテクチャは [1] の表 4 で定義されています。

function dlY = modelGenerator(dlX,parameters)

dlY = fullyconnect(dlX,parameters.FC.Weights,parameters.FC.Bias,Dataformat="SSCB");

dlY = reshape(dlY,[1024 4 4 size(dlY,2)]);
dlY = permute(dlY,[3 2 1 4]);
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv1.Weights,parameters.Conv1.Bias,Stride=2,Cropping="same",DataFormat="SSCB");
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,Stride=2,Cropping="same",DataFormat="SSCB");
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,Stride=2,Cropping="same",DataFormat="SSCB");
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,Stride=2,Cropping="same",DataFormat="SSCB");
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,Stride=2,Cropping="same",DataFormat="SSCB");
dlY = tanh(dlY);
end

モデル ディスクリミネーター関数

関数 modelDiscriminator は、128 x 128 のイメージを入力として受け取り、スカラーの予測スコアを出力します。このディスクリミネーターのアーキテクチャは [1] の表 5 で定義されています。

function dlY = modelDiscriminator(dlX,parameters)

dlY = dlconv(dlX,parameters.Conv1.Weights,parameters.Conv1.Bias,Stride=2,Padding="same");
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,Stride=2,Padding="same");
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,Stride=2,Padding="same");
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,Stride=2,Padding="same");
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,Stride=2,Padding="same");
dlY = leakyrelu(dlY,0.2);
 
dlY = stripdims(dlY);
dlY = permute(dlY,[3 2 1 4]);
dlY = reshape(dlY,4*4*64*16,numel(dlY)/(4*4*64*16));

weights = parameters.FC.Weights;
bias = parameters.FC.Bias;
dlY = fullyconnect(dlY,weights,bias,Dataformat="CB");

end

ディスクリミネーター勾配関数のモデル化

関数 modelDiscriminatorGradients は、ジェネレーターおよびディスクリミネーターの generatorParameters パラメーターと discriminatorParameters パラメーター、入力データのミニバッチ X、および乱数値の配列 Z を入力として受け取り、ネットワーク内の学習可能パラメーターについてのディスクリミネーターの損失の勾配を返します。

function gradientsDiscriminator = modelDiscriminatorGradients(discriminatorParameters,generatorParameters,X,Z)

% Calculate the predictions for real data with the discriminator network.
Y = modelDiscriminator(X,discriminatorParameters);

% Calculate the predictions for generated data with the discriminator network.
Xgen = modelGenerator(Z,generatorParameters);
Ygen = modelDiscriminator(dlarray(Xgen,"SSCB"),discriminatorParameters);

% Calculate the GAN loss.
lossDiscriminator = ganDiscriminatorLoss(Y,Ygen);

% For each network, calculate the gradients with respect to the loss.
gradientsDiscriminator = dlgradient(lossDiscriminator,discriminatorParameters);

end

ジェネレーター勾配関数のモデル化

関数 modelGeneratorGradients は、ディスクリミネーターおよびジェネレーターの学習可能パラメーターと、乱数値の配列 Z を入力として受け取り、ネットワーク内の学習可能パラメーターについてのジェネレーターの損失の勾配を返します。

function gradientsGenerator = modelGeneratorGradients(discriminatorParameters,generatorParameters,Z)

% Calculate the predictions for generated data with the discriminator network.
Xgen = modelGenerator(Z,generatorParameters);
Ygen = modelDiscriminator(dlarray(Xgen,"SSCB"),discriminatorParameters);

% Calculate the GAN loss
lossGenerator = ganGeneratorLoss(Ygen);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator,generatorParameters);

end

ディスクリミネーターの損失関数

ディスクリミネーターの目的はジェネレーターに騙されないことです。ディスクリミネーターが実イメージと生成イメージを正常に区別する確率を最大化するには、ディスクリミネーターの損失関数を最小化します。ジェネレーターの損失関数は、[1] で示された DCGAN 法に従います。

function  lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated)

fake = dlarray(zeros(1,size(dlYPred,2)));
real = dlarray(ones(1,size(dlYPred,2)));

D_loss = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,fake));
D_loss = D_loss + mean(sigmoid_cross_entropy_with_logits(dlYPred,real));
lossDiscriminator  = D_loss / 2;
end

ジェネレーターの損失関数

ジェネレーターの目的はディスクリミネーターが "実データ" に分類するようなデータを生成することです。ジェネレーターが生成したイメージをディスクリミネーターが実イメージとして分類する確率を最大化するには、ジェネレーターの損失関数を最小化します。ジェネレーターの損失関数は、[1] で示された深層畳み込み敵対的生成ネットワーク (DCGAN) 法に従います。

function lossGenerator = ganGeneratorLoss(dlYPredGenerated)
real = dlarray(ones(1,size(dlYPredGenerated,2)));
lossGenerator = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,real));
end

ディスクリミネーターの重みの初期化子

initializeDiscriminatorWeights は、Glorot アルゴリズムを使用してディスクリミネーターの重みを初期化します。

function discriminatorParameters = initializeDiscriminatorWeights

filterSize = [5 5];
dim = 64;

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,dim,"single");
discriminatorParameters.Conv1.Weights = dlarray(weights);
discriminatorParameters.Conv1.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,2*dim,"single");
discriminatorParameters.Conv2.Weights = dlarray(weights);
discriminatorParameters.Conv2.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,4*dim,"single");
discriminatorParameters.Conv3.Weights = dlarray(weights);
discriminatorParameters.Conv3.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,8*dim,"single");
discriminatorParameters.Conv4.Weights = dlarray(weights);
discriminatorParameters.Conv4.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,16*dim,"single");
discriminatorParameters.Conv5.Weights = dlarray(weights);
discriminatorParameters.Conv5.Bias = dlarray(bias);

% fully connected
weights = iGlorotInitialize([1,4 * 4 * dim * 16]);
bias = zeros(1,1,"single");
discriminatorParameters.FC.Weights = dlarray(weights);
discriminatorParameters.FC.Bias = dlarray(bias);
end

ジェネレーターの重みの初期化子

initializeGeneratorWeights は、Glorot アルゴリズムを使用してジェネレーターの重みを初期化します。

function generatorParameters = initializeGeneratorWeights

dim = 64;

% Dense 1
weights = iGlorotInitialize([dim*256,100]);
bias = zeros(dim*256,1,"single");
generatorParameters.FC.Weights = dlarray(weights);
generatorParameters.FC.Bias = dlarray(bias);

filterSize = [5 5];

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,dim*8,"single");
generatorParameters.Conv1.Weights = dlarray(weights);
generatorParameters.Conv1.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,dim*4,"single");
generatorParameters.Conv2.Weights = dlarray(weights);
generatorParameters.Conv2.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,dim*2,"single");
generatorParameters.Conv3.Weights = dlarray(weights);
generatorParameters.Conv3.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,dim,"single");
generatorParameters.Conv4.Weights = dlarray(weights);
generatorParameters.Conv4.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,1,"single");
generatorParameters.Conv5.Weights = dlarray(weights);
generatorParameters.Conv5.Bias = dlarray(bias);
end

ドラムビートの合成

synthesizeDrumBeat は、事前学習済みのネットワークを使用してドラムビートを合成します。

function y = synthesizeDrumBeat

persistent pGeneratorParameters pMean pSTD
if isempty(pGeneratorParameters)
    % If the MAT file does not exist, download it
    filename = "drumGeneratorWeights.mat";
    load(filename,"SMean","SStd","generatorParameters");
    pMean = SMean;
    pSTD  = SStd;
    pGeneratorParameters = generatorParameters;
end

% Generate random vector
dlZ = dlarray(2 * ( rand(1,1,100,1,"single") - 0.5 ));

% Generate spectrograms
dlXGenerated = modelGenerator(dlZ,pGeneratorParameters);

% Convert from dlarray to single
S = dlXGenerated.extractdata;

S = S.';
% Zero-pad to remove edge effects
S = [S ; zeros(1,128)];

% Reverse steps from training
S = S * 3;
S = (S.*pSTD) + pMean;
S = exp(S);

% Make it two-sided
S = [S ; S(end-1:-1:2,:)];
% Pad with zeros at end and start
S = [zeros(256,100) S zeros(256,100)];

% Reconstruct the signal using a fast Griffin-Lim algorithm.
myAudio = stftmag2sig(S,256, ...
    FrequencyRange="twosided", ...
    Window=hann(256,"periodic"), ...
    OverlapLength=128, ...
    MaxIterations=20, ...
    Method="fgla");
myAudio = myAudio./max(abs(myAudio),[],"all");
y = myAudio(128*100:end-128*100);
end

ユーティリティ関数

function out = sigmoid_cross_entropy_with_logits(x,z)
out = max(x, 0) - x .* z + log(1 + exp(-abs(x)));
end

function w = iGlorotInitialize(sz)
if numel(sz) == 2
    numInputs = sz(2);
    numOutputs = sz(1);
else
    numInputs = prod(sz(1:3));
    numOutputs = prod(sz([1 2 4]));
end
multiplier = sqrt(2 / (numInputs + numOutputs));
w = multiplier * sqrt(3) * (2 * rand(sz,"single") - 1);
end

参考文献

[1] Donahue, C., J. McAuley, and M. Puckette. 2019."Adversarial Audio Synthesis." ICLR.