Main Content

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

敵対的生成ネットワーク (GAN) の学習

この例では、敵対的生成ネットワーク (GAN) に学習させてイメージを生成する方法を説明します。

敵対的生成ネットワーク (GAN) は深層学習ネットワークの一種で、入力された実データに類似した特性をもつデータを生成できます。

GAN は一緒に学習を行う 2 つのネットワークで構成されています。

  1. ジェネレーター — このネットワークは、乱数値 (潜在入力) のベクトルを入力として与えられ、学習データと同じ構造のデータを生成します。

  2. ディスクリミネーター — このネットワークは、学習データとジェネレーターにより生成されたデータの両方からの観測値を含むデータのバッチを与えられ、その観測値が "実データ" か "生成データ" かの分類を試みます。

GAN に学習させる場合は、両方のネットワークの学習を同時に行うことで両者の性能を最大化します。

  • ジェネレーターに学習させて、ディスクリミネーターを "騙す" データを生成。

  • ディスクリミネーターに学習させて、実データと生成データを区別。

ジェネレーターの性能を最適化するには、生成データが与えられたときのディスクリミネーターの損失を最大化します。つまり、ジェネレーターの目的はディスクリミネーターが "実データ" と分類するようなデータを生成することです。

ディスクリミネーターの性能を最適化するには、実データと生成データの両方のバッチが与えられたときのディスクリミネーターの損失を最小化します。つまり、ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。

これらの方法によって、十分に現実的なデータを生成するジェネレーターと、学習データの特性である強い特徴表現を学習したディスクリミネーターを得ることが理想的な結果です。

学習データの読み込み

Flowers のデータセット [1] をダウンロードし、解凍します。

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

花の写真のイメージ データストアを作成します。

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

データを拡張して水平方向にランダムに反転させ、イメージのサイズを 64 x 64 に変更します。

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

ジェネレーター ネットワークの定義

乱数値の 1 x 1 x 100 の配列からイメージを生成するネットワーク アーキテクチャを以下のように定義します。

このネットワークは、次を行います。

  • "投影形状変更" 層を使用して、ノイズで構成される 1 x 1 x 100 の配列を 7 x 7 x 128 の配列に変換。

  • バッチ正規化と ReLU 層を用いた一連の転置畳み込み層を使用して、結果の配列を 64 x 64 x 3 の配列にスケール アップ。

このネットワーク アーキテクチャを層グラフとして定義し、次のネットワーク プロパティを指定します。

  • 転置畳み込み層では、5 x 5 のフィルターを指定し、各層でフィルター数を減らし、ストライドを 2 にし、各エッジの出力をトリミングするように設定。

  • 最後の転置畳み込み層では、生成されたイメージの 3 つの RGB チャネルに対応する 3 つの 5 x 5 のフィルターと、前の層の出力サイズを設定。

  • ネットワークの最後に、tanh 層を追加。

ノイズ入力を投影して形状変更するには、この例にサポート ファイルとして添付されている、カスタム層 projectAndReshapeLayer を使用します。projectAndReshapeLayer 層は、全結合演算を使用して入力をスケール アップし、出力を指定サイズに形状変更します。

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersGenerator = [
    imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bnorm1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    batchNormalizationLayer('Name','bnorm2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    batchNormalizationLayer('Name','bnorm3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork オブジェクトに変換します。

dlnetGenerator = dlnetwork(lgraphGenerator);

ディスクリミネーター ネットワークの定義

64 x 64 の実イメージと生成イメージを分類する、次のネットワークを定義します。

64 x 64 x 3 のイメージを受け取り、バッチ正規化と Leaky ReLU 層のある一連の畳み込み層を使用してスカラーの予測スコアを返すネットワークを作成します。ドロップアウトを使用して、入力イメージにノイズを追加します。

  • ドロップアウト層で、ドロップアウトの確率を 0.5 に設定。

  • 畳み込み層で、5 x 5 のフィルターを指定し、各層でフィルター数を増やす。また、ストライド 2 で出力をパディングするように指定。

  • Leaky ReLU 層で、スケールを 0.2 に設定。

  • 最後の層で、4 x 4 のフィルターを 1 つもつ畳み込み層を設定。

範囲 [0,1] の確率を出力するには、モデル勾配関数の関数 sigmoid を使用します

dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    dropoutLayer(0.5,'Name','dropout')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork オブジェクトに変換します。

dlnetDiscriminator = dlnetwork(lgraphDiscriminator);

モデル勾配、損失関数、およびスコアの定義

例のモデル勾配関数の節にリストされている関数 modelGradients を作成します。この関数は、ジェネレーターおよびディスクリミネーター ネットワーク、入力データのミニバッチ、および乱数値と反転係数の配列を入力として受け取り、ネットワーク内の学習可能なパラメーターについての損失の勾配と、2 つのネットワークのスコアを返します。

学習オプションの指定

ミニバッチ サイズを 128 として 500 エポック学習させます。また、拡張イメージ データストアの読み取りサイズをミニバッチのサイズに設定します。大きなデータセットでは、学習させるエポック数をこれより少なくできる場合があります。

numEpochs = 500;
miniBatchSize = 128;
augimds.MiniBatchSize = miniBatchSize;

Adam 最適化のオプションを指定します。両方のネットワークで次のように設定します。

  • 学習率 0.0002

  • 勾配減衰係数 0.5

  • 2 乗勾配減衰係数 0.999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。

executionEnvironment = "auto";

実イメージと生成イメージとを区別するディスクリミネーターの学習速度が速すぎる場合、ジェネレーターの学習に失敗する可能性があります。ディスクリミネーターとジェネレーターの学習バランスを改善するために、ラベルをランダムに反転させて実データにノイズを加えます。

実ラベルの 30% を反転するように指定します。これは、ラベルの総数の 15% が反転することを意味します。生成されたイメージにはすべて正しいラベルが付いているので、これがジェネレーターに損失を与えることはない点に注意してください。

flipFactor = 0.3;

生成された検証イメージを 100 回の反復ごとに表示します。

validationFrequency = 100;

モデルの学習

カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し、各反復でネットワーク パラメーターを更新します。学習の進行状況を監視するには、ホールドアウトされた乱数値の配列をジェネレーターに入力して得られた生成イメージのバッチと、スコアのプロットを表示します。

Adam のパラメーターを初期化します。

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

学習の進行状況を監視するには、ホールドアウトされた乱数値の固定配列のバッチをジェネレーターに渡して得られた生成イメージのバッチを表示し、ネットワークのスコアをプロットします。

ホールドアウトされた乱数値の配列を作成します。

numValidationImages = 25;
ZValidation = randn(1,1,numLatentInputs,numValidationImages,'single');

データを dlarray オブジェクトに変換し、次元ラベル 'SSCB' (spatial、spatial、channel、batch) を指定します。

dlZValidation = dlarray(ZValidation,'SSCB');

GPU での学習のために、データを gpuArray オブジェクトに変換します。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

学習の進行状況プロットを初期化します。Figure を作成して幅が 2 倍になるようサイズ変更します。

f = figure;
f.Position(3) = 2*f.Position(3);

生成イメージとネットワーク スコアのサブプロットを作成します。

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

スコアのプロット用にアニメーションの線を初期化します。

lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);
legend('Generator','Discriminator');
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

GAN に学習させます。各エポックで、データストアをシャッフルしてデータのミニバッチについてループします。

各ミニバッチで次を行います。

  • イメージを範囲 [-1 1] で再スケーリング。

  • 基となる型が singledlarray オブジェクトにデータを変換し、次元ラベルを 'SSCB' (spatial、spatial、channel、batch) に指定。

  • ジェネレーター ネットワーク用の乱数値の配列を含む dlarray オブジェクトを生成。

  • GPU での学習のために、データを gpuArray オブジェクトに変換。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配を評価します。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新。

  • 2 つのネットワークのスコアをプロット。

  • validationFrequency の反復がすべて終了した後で、ホールドアウトされた固定ジェネレーター入力の生成イメージのバッチを表示。

学習を行うのに時間がかかる場合があります。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Concatenate mini-batch of data and generate latent inputs for the
        % generator network.
        X = cat(4,data{:,1}{:});
        X = single(X);
        Z = randn(1,1,numLatentInputs,size(X,4),'single');
        
        % Rescale the images in the range [-1 1].
        X = rescale(X,-1,1,'InputMin',0,'InputMax',255);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % 'SSCB' (spatial, spatial, channel, batch).
        dlX = dlarray(X, 'SSCB');
        dlZ = dlarray(Z, 'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation));
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

ここでは、ディスクリミネーターは生成イメージの中から実イメージを識別する強い特徴表現を学習しました。それに対し、ジェネレーターは実データのように見えるデータを生成できるように、同様に強い特徴表現を学習しました。

学習プロットは、ジェネレーターおよびディスクリミネーターのネットワークのスコアを示しています。ネットワークのスコアを解釈する方法の詳細については、GAN の学習過程の監視と一般的な故障モードの識別を参照してください。

新しいイメージの生成

新しいイメージを生成するには、ジェネレーターに対して関数 predict を使用して、乱数値の 1 x 1 x 100 の配列のバッチを含む dlarray オブジェクトを指定します。イメージを並べて表示するには関数 imtile を使用し、関数 rescale を使ってイメージを再スケーリングします。

乱数値の 1 x 1 x 100 の配列 25 個のバッチを含む dlarray オブジェクトを作成します。

ZNew = randn(1,1,numLatentInputs,25,'single');
dlZNew = dlarray(ZNew,'SSCB');

GPU を使用してイメージを生成するには、データを gpuArray オブジェクトにも変換します。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

関数 predict をジェネレーターと入力データと共に使用して、新しいイメージを生成します。

dlXGeneratedNew = predict(dlnetGenerator,dlZNew);

イメージを表示します。

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

モデル勾配関数

関数 modelGradients は、ジェネレーターおよびディスクリミネーターの dlnetwork オブジェクトである dlnetGeneratordlnetDiscriminator、入力データのミニバッチ dlX、乱数値の配列 dlZ 、および実ラベルの反転する割合 flipFactor, を入力として受け取り、ネットワーク内の学習可能なパラメーターについての損失の勾配、ジェネレーターの状態、および 2 つのネットワークのスコアを返します。ディスクリミネーターの出力は範囲 [0,1] に含まれないため、modelGradients はシグモイド関数を適用してこれを確率に変換します。

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

% Convert the discriminator outputs to probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the score of the discriminator.
scoreDiscriminator = ((mean(probReal)+mean(1-probGenerated))/2);

% Calculate the score of the generator.
scoreGenerator = mean(probGenerated);

% Randomly flip a fraction of the labels of the real images.
numObservations = size(probReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));

% Flip the labels
probReal(:,:,:,idx) = 1-probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

GAN の損失関数とスコア

ジェネレーターの目的はディスクリミネーターが "実データ" に分類するようなデータを生成することです。ジェネレーターが生成したイメージをディスクリミネーターが実データとして分類する確率を最大化するには、負の対数尤度関数を最小化します。

ディスクリミネーターの出力 Y が与えられた場合、次のようになります。

  • Yˆ=σ(Y) は、入力イメージが "実" クラスに属する確率です。

  • 1-Yˆ は、入力イメージが "生成" クラスに属している確率です。

シグモイド演算 σ は関数 modelGradients で行われる点に注意してください。ジェネレーターの損失関数は次の式で表されます。

lossGenerator=-mean(log(YˆGenerated)),

ここで、YˆGenerated は生成イメージに対するディスクリミネーターの出力確率を表しています。

ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。ディスクリミネーターが実イメージと生成イメージを正しく区別する確率を最大化するには、対応する負の対数尤度関数の和を最小化します。

ディスクリミネーターの損失関数は次の式で表されます。

lossDiscriminator=-mean(log(YˆReal))-mean(log(1-YˆGenerated)),

ここで、YˆReal は実イメージに対するディスクリミネーターの出力確率を表しています。

ジェネレーターとディスクリミネーターがそれぞれの目標をどれだけ達成するかを 0 から 1 のスケールで測定するには、スコアの概念を使用できます。

ジェネレーターのスコアは、生成イメージに対するディスクリミネーターの出力に対応する確率の平均です。

scoreGenerator=mean(YˆGenerated).

ディスクリミネーターのスコアは、実イメージと生成イメージの両方に対するディスクリミネーターの出力に対応する確率の平均です。

scoreDiscriminator=12mean(YˆReal)+12mean(1-YˆGenerated).

スコアは損失に反比例しますが、実質的には同じ情報を表しています。

function [lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated)

% Calculate the loss for the discriminator network.
lossDiscriminator =  -mean(log(probReal)) -mean(log(1-probGenerated));

% Calculate the loss for the generator network.
lossGenerator = -mean(log(probGenerated));

end

参考文献

  1. The TensorFlow Team.Flowers http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Radford, Alec, Luke Metz, and Soumith Chintala."Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).

参考

| | | | | |

関連するトピック