Main Content

敵対的生成ネットワーク (GAN) の学習

この例では、敵対的生成ネットワークに学習させてイメージを生成する方法を説明します。

敵対的生成ネットワーク (GAN) は深層学習ネットワークの一種で、入力された実データに類似した特性をもつデータを生成できます。

関数 trainNetwork は GAN の学習をサポートしていないため、カスタム学習ループを実装しなければなりません。カスタム学習ループを使用して GAN に学習させるには、自動微分のために dlarray オブジェクトと dlnetwork オブジェクトを使用できます。

GAN は一緒に学習を行う 2 つのネットワークで構成されています。

  1. ジェネレーター — このネットワークは、乱数値 (潜在入力) のベクトルを入力として与えられ、学習データと同じ構造のデータを生成します。

  2. ディスクリミネーター。学習データとジェネレーターにより生成されたデータの両方の観測値を含むデータのバッチを与えられ、その観測値が "real""generated" かの分類を試みます。

次の図は、ランダム入力のベクトルからイメージを生成する GAN のジェネレーター ネットワークを示しています。

次の図は、GAN の構造を示しています。

GAN に学習させる場合は、両方のネットワークの学習を同時に行うことで両者の性能を最大化します。

  • ジェネレーターに学習させて、ディスクリミネーターを "騙す" データを生成。

  • ディスクリミネーターに学習させて、実データと生成データを区別。

ジェネレーターの性能を最適化するために、生成データが与えられたときのディスクリミネーターの損失を最大化します。つまり、ジェネレーターの目的はディスクリミネーターが "real" と分類するようなデータを生成することです。

ディスクリミネーターの性能を最適化するために、実データと生成データの両方のバッチが与えられたときのディスクリミネーターの損失を最小化します。つまり、ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。

これらの方法によって、いかにも本物らしいデータを生成するジェネレーターと、学習データの特性である強い特徴表現を学習したディスクリミネーターを得ることが理想的な結果です。

この例では、花のイメージを含む Flowers データ セット [1] を使用してイメージを生成するように GAN に学習させます。

学習データの読み込み

Flowers のデータセット [1] をダウンロードし、解凍します。

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flowers data set (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

花の写真のイメージ データストアを作成します。

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder,IncludeSubfolders=true);

データを拡張して水平方向にランダムに反転させ、イメージのサイズを 64 x 64 に変更します。

augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);

敵対的生成ネットワークの定義

GAN は一緒に学習を行う 2 つのネットワークで構成されています。

  1. ジェネレーター — このネットワークは、乱数値 (潜在入力) のベクトルを入力として与えられ、学習データと同じ構造のデータを生成します。

  2. ディスクリミネーター。学習データとジェネレーターにより生成されたデータの両方の観測値を含むデータのバッチを与えられ、その観測値が "real""generated" かの分類を試みます。

次の図は、GAN の構造を示しています。

ジェネレーター ネットワークの定義

ランダム ベクトルからイメージを生成する以下のネットワーク アーキテクチャを定義します。

このネットワークは、次を行います。

  • 投影および形状変更操作を使用して、サイズ 100 のランダム ベクトルを 4×4×512 の配列に変換。

  • バッチ正規化と ReLU 層を用いた一連の転置畳み込み層を使用して、結果の配列を 64 x 64 x 3 の配列にアップスケール。

このネットワーク アーキテクチャを層グラフとして定義し、次のネットワーク プロパティを指定します。

  • 転置畳み込み層では、各層のフィルター数を減らして 5 x 5 のフィルターを指定し、ストライドを 2 に指定し、各エッジでの出力のトリミングを指定します。

  • 最後の転置畳み込み層では、生成されたイメージの 3 つの RGB チャネルに対応する 3 つの 5 x 5 のフィルターと、前の層の出力サイズを設定。

  • ネットワークの最後に、tanh 層を追加。

ノイズ入力を投影して形状変更するには、この例にサポート ファイルとして添付されている、カスタム層 projectAndReshapeLayer を使用します。この層にアクセスするには、この例をライブ スクリプトとして開きます。

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersGenerator = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize,numLatentInputs)
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork オブジェクトに変換します。

netG = dlnetwork(layersGenerator);

ディスクリミネーター ネットワークの定義

64 x 64 の実イメージと生成イメージを分類する、次のネットワークを定義します。

64 x 64 x 3 のイメージを受け取り、バッチ正規化と leaky ReLU 層のある一連の畳み込み層を使用してスカラーの予測スコアを返すネットワークを作成します。ドロップアウトを使用して、入力イメージにノイズを追加します。

  • ドロップアウト層で、ドロップアウトの確率を 0.5 に設定。

  • 畳み込み層で、5 x 5 のフィルターを指定し、各層でフィルター数を増やす。また、ストライドを 2 に指定し、出力のパディングを指定します。

  • leaky ReLU 層で、スケールを 0.2 に設定。

  • [0,1] の範囲の確率を出力するために、4 行 4 列のフィルターを 1 つもつ畳み込み層に続けてシグモイド層を指定。

dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)
    sigmoidLayer];

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork オブジェクトに変換します。

netD = dlnetwork(layersDiscriminator);

モデル損失関数の定義

この例のモデル損失関数の節にリストされている関数 modelLoss を作成します。この関数は、ジェネレーターおよびディスクリミネーターのネットワーク、入力データのミニバッチ、および乱数値と反転係数の配列を入力として受け取り、ネットワーク内の学習可能パラメーターについての損失値と損失値の勾配、ジェネレーターの状態、および 2 つのネットワークのスコアを返します。

学習オプションの指定

ミニバッチ サイズを 128 として 500 エポック学習させます。大きなデータ セットでは、学習させるエポック数をこれより少なくできる場合があります。

numEpochs = 500;
miniBatchSize = 128;

Adam 最適化のオプションを指定します。両方のネットワークで次のように設定します。

  • 学習率 0.0002

  • 勾配の減衰係数 0.5

  • 2 乗勾配の減衰係数 0.999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

実イメージと生成イメージとを区別するディスクリミネーターの学習速度が速すぎる場合、ジェネレーターの学習に失敗する可能性があります。ディスクリミネーターとジェネレーターの学習バランスを改善するために、実イメージに割り当てられているラベルをランダムに反転させて実データにノイズを加えます。

0.35 の確率で実ラベルを反転するように設定します。生成されたイメージにはすべて正しいラベルが付いているので、これがジェネレーターに損失を与えることはない点に注意してください。

flipProb = 0.35;

生成された検証イメージを 100 回の反復ごとに表示します。

validationFrequency = 100;

モデルの学習

GAN に学習させる場合は、両方のネットワークの学習を同時に行うことで両者の性能を最大化します。

  • ジェネレーターに学習させて、ディスクリミネーターを "騙す" データを生成。

  • ディスクリミネーターに学習させて、実データと生成データを区別。

ジェネレーターの性能を最適化するために、生成データが与えられたときのディスクリミネーターの損失を最大化します。つまり、ジェネレーターの目的はディスクリミネーターが "real" と分類するようなデータを生成することです。

ディスクリミネーターの性能を最適化するために、実データと生成データの両方のバッチが与えられたときのディスクリミネーターの損失を最小化します。つまり、ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。

これらの方法によって、いかにも本物らしいデータを生成するジェネレーターと、学習データの特性である強い特徴表現を学習したディスクリミネーターを得ることが理想的な結果です。

minibatchqueueを使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、イメージを範囲 [-1,1] で再スケーリングします。

  • 指定したミニバッチ サイズ未満の観測値をもつ部分的なミニバッチは破棄します。

  • イメージ データを "SSCB" (spatial、spatial、channel、batch) 形式で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

augimds.MiniBatchSize = miniBatchSize;

mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");

カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し、各反復でネットワーク パラメーターを更新します。学習の進行状況を監視するには、ホールドアウトされた乱数値の配列をジェネレーターに入力して得られた生成イメージのバッチと、スコアのプロットを表示します。

Adam 最適化のパラメーターを初期化します。

trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];

学習の進行状況を監視するには、ホールドアウトされたランダム固定ベクトルのバッチをジェネレーターに渡して得られた生成イメージのバッチを表示し、ネットワークのスコアをプロットします。

ホールドアウトされた乱数値の配列を作成します。

numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");

データを dlarray オブジェクトに変換し、形式 "CB" (channel、batch) を指定します。

ZValidation = dlarray(ZValidation,"CB");

GPU での学習のために、データを gpuArray オブジェクトに変換します。

if canUseGPU
    ZValidation = gpuArray(ZValidation);
end

学習の進行状況プロットを初期化します。Figure を作成して幅が 2 倍になるようサイズ変更します。

f = figure;
f.Position(3) = 2*f.Position(3);

生成イメージとネットワーク スコアのサブプロットを作成します。

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

スコアのプロット用にアニメーションの線を初期化します。

C = colororder;
lineScoreG = animatedline(scoreAxes,Color=C(1,:));
lineScoreD = animatedline(scoreAxes,Color=C(2,:));
legend("Generator","Discriminator");
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

GAN に学習させます。各エポックで、データストアをシャッフルしてデータのミニバッチについてループします。

各ミニバッチで次を行います。

  • 関数 dlfeval と関数 modelLoss を使用して、学習可能なパラメーターについての損失の勾配、ジェネレーターの状態、およびネットワークのスコアを評価。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新。

  • 2 つのネットワークのスコアをプロット。

  • validationFrequency 回の反復が終わるごとに、ホールドアウトされた固定ジェネレーター入力の生成イメージのバッチを表示。

学習を行うのに時間がかかる場合があります。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Reset and shuffle datastore.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        X = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the format "CB" (channel, batch). If a GPU is
        % available, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");

        if canUseGPU
            Z = gpuArray(Z);
        end

        % Evaluate the gradients of the loss with respect to the learnable
        % parameters, the generator state, and the network scores using
        % dlfeval and the modelLoss function.
        [~,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,Z,flipProb);
        netG.State = stateG;

        % Update the discriminator network parameters.
        [netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvg, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Update the generator network parameters.
        [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Every validationFrequency iterations, display batch of generated
        % images using the held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            XGeneratedValidation = predict(netG,ZValidation);

            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(XGeneratedValidation));
            I = rescale(I);

            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end

        % Update the scores plot.
        subplot(1,2,2)
        scoreG = double(extractdata(scoreG));
        addpoints(lineScoreG,iteration,scoreG);

        scoreD = double(extractdata(scoreD));
        addpoints(lineScoreD,iteration,scoreD);

        % Update the title with training progress information.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))

        drawnow
    end
end

ここでは、ディスクリミネーターは生成イメージの中から実イメージを識別する強い特徴表現を学習しました。次に、ジェネレーターは、学習データと同様のイメージを生成できるように、同様に強い特徴表現を学習しました。

学習プロットは、ジェネレーターおよびディスクリミネーターのネットワークのスコアを示しています。ネットワークのスコアを解釈する方法の詳細については、GAN の学習過程の監視と一般的な故障モードの識別を参照してください。

新しいイメージの生成

新しいイメージを生成するには、ジェネレーターに対して関数 predict を使用して、ランダム ベクトルのバッチを含む dlarray オブジェクトを指定します。イメージを並べて表示するには関数 imtile を使用し、関数 rescale を使ってイメージを再スケーリングします。

ジェネレーター ネットワークに入力するランダム ベクトル 25 個のバッチを含む dlarray オブジェクトを作成します。

numObservations = 25;
ZNew = randn(numLatentInputs,numObservations,"single");
ZNew = dlarray(ZNew,"CB");

GPU が利用可能な場合は、潜在ベクトルを gpuArray に変換します。

if canUseGPU
    ZNew = gpuArray(ZNew);
end

関数 predict をジェネレーターと入力データと共に使用して、新しいイメージを生成します。

XGeneratedNew = predict(netG,ZNew);

イメージを表示します。

I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

モデル損失関数

関数 modelLoss は、ジェネレーターおよびディスクリミネーターの dlnetwork オブジェクトである netGnetD、入力データのミニバッチ X、乱数値の配列 Z、および実ラベルの反転する確率 flipProb を入力として受け取り、ネットワーク内の学習可能パラメーターについての損失値と損失値の勾配、ジェネレーターの状態、および 2 つのネットワークのスコアを返します。

function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
    modelLoss(netG,netD,X,Z,flipProb)

% Calculate the predictions for real data with the discriminator network.
YReal = forward(netD,X);

% Calculate the predictions for generated data with the discriminator
% network.
[XGenerated,stateG] = forward(netG,Z);
YGenerated = forward(netD,XGenerated);

% Calculate the score of the discriminator.
scoreD = (mean(YReal) + mean(1-YGenerated)) / 2;

% Calculate the score of the generator.
scoreG = mean(YGenerated);

% Randomly flip the labels of the real images.
numObservations = size(YReal,4);
idx = rand(1,numObservations) < flipProb;
YReal(:,:,:,idx) = 1 - YReal(:,:,:,idx);

% Calculate the GAN loss.
[lossG, lossD] = ganLoss(YReal,YGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);

end

GAN の損失関数とスコア

ジェネレーターの目的はディスクリミネーターが "real" と分類するようなデータを生成することです。ジェネレーターが生成したイメージをディスクリミネーターが実データとして分類する確率を最大化するには、負の対数尤度関数を最小化します。

ディスクリミネーターの出力 Y が与えられた場合、次のようになります。

  • Y は、入力イメージがクラス "real" に属する確率です。

  • 1-Y は、入力イメージがクラス "generated" に属する確率です。

ジェネレーターの損失関数は次の式で表されます。

lossGenerator=-mean(log(YGenerated)),

ここで、YGenerated は生成イメージに対するディスクリミネーターの出力確率を表しています。

ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。ディスクリミネーターが実イメージと生成イメージを正しく区別する確率を最大化するには、対応する負の対数尤度関数の和を最小化します。

ディスクリミネーターの損失関数は次の式で表されます。

lossDiscriminator=-mean(log(YReal))-mean(log(1-YGenerated)),

ここで、YReal は実イメージに対するディスクリミネーターの出力確率を表しています。

ジェネレーターとディスクリミネーターがそれぞれの目標をどれだけ達成するかを 0 から 1 のスケールで測定するには、スコアの概念を使用できます。

ジェネレーターのスコアは、生成イメージに対するディスクリミネーターの出力に対応する確率の平均です。

scoreGenerator=mean(YGenerated).

ディスクリミネーターのスコアは、実イメージと生成イメージの両方に対するディスクリミネーターの出力に対応する確率の平均です。

scoreDiscriminator=12mean(YReal)+12mean(1-YGenerated).

スコアは損失に反比例しますが、実質的には同じ情報を表しています。

function [lossG,lossD] = ganLoss(YReal,YGenerated)

% Calculate the loss for the discriminator network.
lossD = -mean(log(YReal)) - mean(log(1-YGenerated));

% Calculate the loss for the generator network.
lossG = -mean(log(YGenerated));

end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 入力 cell 配列からイメージ データを抽出して数値配列に連結します。

  2. イメージの範囲が [-1,1] となるように再スケーリングします。

function X = preprocessMiniBatch(data)

% Concatenate mini-batch
X = cat(4,data{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,InputMin=0,InputMax=255);

end

参考文献

  1. The TensorFlow Team.Flowers http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Radford, Alec, Luke Metz, and Soumith Chintala."Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks." Preprint, submitted November 19, 2015. http://arxiv.org/abs/1511.06434.

参考

| | | | | | |

関連するトピック