敵対的生成ネットワーク (GAN) の学習
この例では、敵対的生成ネットワークに学習させてイメージを生成する方法を説明します。
敵対的生成ネットワーク (GAN) は深層学習ネットワークの一種で、入力された実データに類似した特性をもつデータを生成できます。
関数 trainnet
は GAN の学習をサポートしていないため、カスタム学習ループを実装しなければなりません。カスタム学習ループを使用して GAN に学習させるには、自動微分のために dlarray
オブジェクトと dlnetwork
オブジェクトを使用できます。
GAN は一緒に学習を行う 2 つのネットワークで構成されています。
ジェネレーター — このネットワークは、乱数値 (潜在入力) のベクトルを入力として与えられ、学習データと同じ構造のデータを生成します。
ディスクリミネーター。学習データとジェネレーターにより生成されたデータの両方の観測値を含むデータのバッチを与えられ、その観測値が
"real"
か"generated"
かの分類を試みます。
次の図は、ランダム入力のベクトルからイメージを生成する GAN のジェネレーター ネットワークを示しています。
次の図は、GAN の構造を示しています。
GAN に学習させる場合は、両方のネットワークの学習を同時に行うことで両者の性能を最大化します。
ジェネレーターに学習させて、ディスクリミネーターを "騙す" データを生成。
ディスクリミネーターに学習させて、実データと生成データを区別。
ジェネレーターの性能を最適化するために、生成データが与えられたときのディスクリミネーターの損失を最大化します。つまり、ジェネレーターの目的はディスクリミネーターが "real"
と分類するようなデータを生成することです。
ディスクリミネーターの性能を最適化するために、実データと生成データの両方のバッチが与えられたときのディスクリミネーターの損失を最小化します。つまり、ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。
これらの方法によって、いかにも本物らしいデータを生成するジェネレーターと、学習データの特性である強い特徴表現を学習したディスクリミネーターを得ることが理想的な結果です。
この例では、花のイメージを含む Flowers データ セット [1] を使用してイメージを生成するように GAN に学習させます。
学習データの読み込み
Flowers データ セット [1] をダウンロードし、解凍します。
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz"); imageFolder = fullfile(downloadFolder,"flower_photos"); if ~datasetExists(imageFolder) disp("Downloading Flowers data set (218 MB)...") websave(filename,url); untar(filename,downloadFolder) end
花の写真のイメージ データストアを作成します。
imds = imageDatastore(imageFolder,IncludeSubfolders=true);
データを拡張して水平方向にランダムに反転させ、イメージのサイズを 64 x 64 に変更します。
augmenter = imageDataAugmenter(RandXReflection=true); augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
敵対的生成ネットワークの定義
GAN は一緒に学習を行う 2 つのネットワークで構成されています。
ジェネレーター — このネットワークは、乱数値 (潜在入力) のベクトルを入力として与えられ、学習データと同じ構造のデータを生成します。
ディスクリミネーター。学習データとジェネレーターにより生成されたデータの両方の観測値を含むデータのバッチを与えられ、その観測値が
"real"
か"generated"
かの分類を試みます。
次の図は、GAN の構造を示しています。
ジェネレーター ネットワークの定義
ランダム ベクトルからイメージを生成する以下のネットワーク アーキテクチャを定義します。
このネットワークは、次を行います。
投影および形状変更操作を使用して、サイズ 100 のランダム ベクトルを 4×4×512 の配列に変換。
バッチ正規化と ReLU 層を用いた一連の転置畳み込み層を使用して、結果の配列を 64 x 64 x 3 の配列にアップスケール。
このネットワーク アーキテクチャを層グラフとして定義し、次のネットワーク プロパティを指定します。
転置畳み込み層では、各層のフィルター数を減らして 5 x 5 のフィルターを指定し、ストライドを 2 に指定し、各エッジでの出力のトリミングを指定します。
最後の転置畳み込み層では、生成されたイメージの 3 つの RGB チャネルに対応する 3 つの 5 x 5 のフィルターと、前の層の出力サイズを設定。
ネットワークの最後に、tanh 層を追加。
ノイズ入力を投影して形状変更するには、この例にサポート ファイルとして添付されている、カスタム層 projectAndReshapeLayer
を使用します。この層にアクセスするには、この例をライブ スクリプトとして開きます。
filterSize = 5; numFilters = 64; numLatentInputs = 100; projectionSize = [4 4 512]; layersGenerator = [ featureInputLayer(numLatentInputs) projectAndReshapeLayer(projectionSize) transposedConv2dLayer(filterSize,4*numFilters) batchNormalizationLayer reluLayer transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same") batchNormalizationLayer reluLayer transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same") batchNormalizationLayer reluLayer transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same") tanhLayer];
カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork
オブジェクトに変換します。
netG = dlnetwork(layersGenerator);
ディスクリミネーター ネットワークの定義
64 x 64 の実イメージと生成イメージを分類する、次のネットワークを定義します。
64 x 64 x 3 のイメージを受け取り、バッチ正規化と leaky ReLU 層のある一連の畳み込み層を使用してスカラーの予測スコアを返すネットワークを作成します。ドロップアウトを使用して、入力イメージにノイズを追加します。
ドロップアウト層で、ドロップアウトの確率を 0.5 に設定。
畳み込み層で、5 x 5 のフィルターを指定し、各層でフィルター数を増やす。また、ストライドを 2 に指定し、出力のパディングを指定します。
leaky ReLU 層で、スケールを 0.2 に設定。
[0,1] の範囲の確率を出力するために、4 行 4 列のフィルターを 1 つもつ畳み込み層に続けてシグモイド層を指定。
dropoutProb = 0.5; numFilters = 64; scale = 0.2; inputSize = [64 64 3]; filterSize = 5; layersDiscriminator = [ imageInputLayer(inputSize,Normalization="none") dropoutLayer(dropoutProb) convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same") leakyReluLayer(scale) convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same") batchNormalizationLayer leakyReluLayer(scale) convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same") batchNormalizationLayer leakyReluLayer(scale) convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same") batchNormalizationLayer leakyReluLayer(scale) convolution2dLayer(4,1) sigmoidLayer];
カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork
オブジェクトに変換します。
netD = dlnetwork(layersDiscriminator);
モデル損失関数の定義
この例のモデル損失関数の節にリストされている関数 modelLoss
を作成します。この関数は、ジェネレーターおよびディスクリミネーターのネットワーク、入力データのミニバッチ、および乱数値と反転係数の配列を入力として受け取り、ネットワーク内の学習可能パラメーターについての損失値と損失値の勾配、ジェネレーターの状態、および 2 つのネットワークのスコアを返します。
学習オプションの指定
ミニバッチ サイズを 128 として 500 エポック学習させます。大きなデータ セットでは、学習させるエポック数をこれより少なくできる場合があります。
numEpochs = 500; miniBatchSize = 128;
Adam 最適化のオプションを指定します。両方のネットワークで次のように設定します。
学習率 0.0002
勾配の減衰係数 0.5
2 乗勾配の減衰係数 0.999
learnRate = 0.0002; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999;
実イメージと生成イメージとを区別するディスクリミネーターの学習速度が速すぎる場合、ジェネレーターの学習に失敗する可能性があります。ディスクリミネーターとジェネレーターの学習バランスを改善するために、実イメージに割り当てられているラベルをランダムに反転させて実データにノイズを加えます。
0.35 の確率で実ラベルを反転するように設定します。生成されたイメージにはすべて正しいラベルが付いているので、これがジェネレーターに損失を与えることはない点に注意してください。
flipProb = 0.35;
生成された検証イメージを 100 回の反復ごとに表示します。
validationFrequency = 100;
モデルの学習
GAN に学習させる場合は、両方のネットワークの学習を同時に行うことで両者の性能を最大化します。
ジェネレーターに学習させて、ディスクリミネーターを "騙す" データを生成。
ディスクリミネーターに学習させて、実データと生成データを区別。
ジェネレーターの性能を最適化するために、生成データが与えられたときのディスクリミネーターの損失を最大化します。つまり、ジェネレーターの目的はディスクリミネーターが "real"
と分類するようなデータを生成することです。
ディスクリミネーターの性能を最適化するために、実データと生成データの両方のバッチが与えられたときのディスクリミネーターの損失を最小化します。つまり、ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。
これらの方法によって、いかにも本物らしいデータを生成するジェネレーターと、学習データの特性である強い特徴表現を学習したディスクリミネーターを得ることが理想的な結果です。
minibatchqueue
を使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、イメージを範囲[-1,1]
で再スケーリングします。指定したミニバッチ サイズ未満の観測値をもつ部分的なミニバッチは破棄します。
イメージ データを形式
"SSCB"
(spatial、spatial、channel、batch) で形式を整えます。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。GPU が利用できる場合、GPU で学習を行います。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、各出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
augimds.MiniBatchSize = miniBatchSize; mbq = minibatchqueue(augimds, ... MiniBatchSize=miniBatchSize, ... PartialMiniBatch="discard", ... MiniBatchFcn=@preprocessMiniBatch, ... MiniBatchFormat="SSCB");
カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し、各反復でネットワーク パラメーターを更新します。学習の進行状況を監視するには、ホールドアウトされた乱数値の配列をジェネレーターに入力して得られた生成イメージのバッチと、スコアのプロットを表示します。
Adam 最適化のパラメーターを初期化します。
trailingAvgG = []; trailingAvgSqG = []; trailingAvg = []; trailingAvgSqD = [];
学習の進行状況を監視するには、ホールドアウトされたランダム固定ベクトルのバッチをジェネレーターに渡して得られた生成イメージのバッチを表示し、ネットワークのスコアをプロットします。
ホールドアウトされた乱数値の配列を作成します。
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
データを dlarray
オブジェクトに変換し、形式 "CB"
(channel、batch) を指定します。
ZValidation = dlarray(ZValidation,"CB");
GPU での学習用に、データを gpuArray
オブジェクトに変換します。
if canUseGPU ZValidation = gpuArray(ZValidation); end
ジェネレーターとディスクリミネーターのスコアを追跡するには、trainingProgressMonitor
オブジェクトを使用します。モニター用に合計反復回数を計算します。
numObservationsTrain = numel(imds.Files); numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize); numIterations = numEpochs*numIterationsPerEpoch;
TrainingProgressMonitor
オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor( ... Metrics=["GeneratorScore","DiscriminatorScore"], ... Info=["Epoch","Iteration"], ... XLabel="Iteration"); groupSubPlot(monitor,Score=["GeneratorScore","DiscriminatorScore"])
GAN に学習させます。各エポックで、データストアをシャッフルしてデータのミニバッチについてループします。
各ミニバッチで次を行います。
TrainingProgressMonitor
オブジェクトのStop
プロパティがtrue
のときに停止します。[停止] ボタンをクリックしたときにStop
プロパティがtrue
に変更します。関数
dlfeval
と関数modelLoss
を使用して、学習可能なパラメーターについての損失の勾配、ジェネレーターの状態、およびネットワークのスコアを評価します。関数
adamupdate
を使用してネットワーク パラメーターを更新します。2 つのネットワークのスコアをプロットします。
validationFrequency
回の反復が終わるごとに、ホールドアウトされた固定ジェネレーター入力の生成イメージのバッチを表示します。
学習を行うのに時間がかかる場合があります。
epoch = 0; iteration = 0; % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Reset and shuffle datastore. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) && ~monitor.Stop iteration = iteration + 1; % Read mini-batch of data. X = next(mbq); % Generate latent inputs for the generator network. Convert to % dlarray and specify the format "CB" (channel, batch). If a GPU is % available, then convert latent inputs to gpuArray. Z = randn(numLatentInputs,miniBatchSize,"single"); Z = dlarray(Z,"CB"); if canUseGPU Z = gpuArray(Z); end % Evaluate the gradients of the loss with respect to the learnable % parameters, the generator state, and the network scores using % dlfeval and the modelLoss function. [~,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ... dlfeval(@modelLoss,netG,netD,X,Z,flipProb); netG.State = stateG; % Update the discriminator network parameters. [netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ... trailingAvg, trailingAvgSqD, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Update the generator network parameters. [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ... trailingAvgG, trailingAvgSqG, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Every validationFrequency iterations, display batch of generated % images using the held-out generator input. if mod(iteration,validationFrequency) == 0 || iteration == 1 % Generate images using the held-out generator input. XGeneratedValidation = predict(netG,ZValidation); % Tile and rescale the images in the range [0 1]. I = imtile(extractdata(XGeneratedValidation)); I = rescale(I); % Display the images. image(I) xticklabels([]); yticklabels([]); title("Generated Images"); end % Update the training progress monitor. recordMetrics(monitor,iteration, ... GeneratorScore=scoreG, ... DiscriminatorScore=scoreD); updateInfo(monitor,Epoch=epoch,Iteration=iteration); monitor.Progress = 100*iteration/numIterations; end end
ここでは、ディスクリミネーターは生成イメージの中から実イメージを識別する強い特徴表現を学習しました。次に、ジェネレーターは、学習データと同様のイメージを生成できるように、同様に強い特徴表現を学習しました。
学習プロットは、ジェネレーターおよびディスクリミネーターのネットワークのスコアを示しています。ネットワークのスコアを解釈する方法の詳細については、GAN の学習過程の監視と一般的な故障モードの識別を参照してください。
新しいイメージの生成
新しいイメージを生成するには、ジェネレーターに対して関数 predict
を使用して、ランダム ベクトルのバッチを含む dlarray
オブジェクトを指定します。イメージを並べて表示するには関数 imtile
を使用し、関数 rescale
を使ってイメージを再スケーリングします。
ジェネレーター ネットワークに入力するランダム ベクトル 25 個のバッチを含む dlarray
オブジェクトを作成します。
numObservations = 25; ZNew = randn(numLatentInputs,numObservations,"single"); ZNew = dlarray(ZNew,"CB");
GPU が利用可能な場合は、潜在ベクトルを gpuArray
に変換します。
if canUseGPU ZNew = gpuArray(ZNew); end
関数 predict
をジェネレーターと入力データと共に使用して、新しいイメージを生成します。
XGeneratedNew = predict(netG,ZNew);
イメージを表示します。
I = imtile(extractdata(XGeneratedNew)); I = rescale(I); figure image(I) axis off title("Generated Images")
モデル損失関数
関数 modelLoss
は、ジェネレーターおよびディスクリミネーターの dlnetwork
オブジェクトである netG
と netD
、入力データのミニバッチ X
、乱数値の配列 Z
、および実ラベルの反転する確率 flipProb
を入力として受け取り、ネットワーク内の学習可能パラメーターについての損失値と損失値の勾配、ジェネレーターの状態、および 2 つのネットワークのスコアを返します。
function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ... modelLoss(netG,netD,X,Z,flipProb) % Calculate the predictions for real data with the discriminator network. YReal = forward(netD,X); % Calculate the predictions for generated data with the discriminator % network. [XGenerated,stateG] = forward(netG,Z); YGenerated = forward(netD,XGenerated); % Calculate the score of the discriminator. scoreD = (mean(YReal) + mean(1-YGenerated)) / 2; % Calculate the score of the generator. scoreG = mean(YGenerated); % Randomly flip the labels of the real images. numObservations = size(YReal,4); idx = rand(1,numObservations) < flipProb; YReal(:,:,:,idx) = 1 - YReal(:,:,:,idx); % Calculate the GAN loss. [lossG, lossD] = ganLoss(YReal,YGenerated); % For each network, calculate the gradients with respect to the loss. gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true); gradientsD = dlgradient(lossD,netD.Learnables); end
GAN の損失関数とスコア
ジェネレーターの目的はディスクリミネーターが "real"
と分類するようなデータを生成することです。ジェネレーターが生成したイメージをディスクリミネーターが実データとして分類する確率を最大化するには、負の対数尤度関数を最小化します。
ディスクリミネーターの出力 が与えられた場合、次のようになります。
は、入力イメージがクラス
"real"
に属する確率です。は、入力イメージがクラス
"generated"
に属する確率です。
ジェネレーターの損失関数は次の式で表されます。
ここで、 は生成イメージに対するディスクリミネーターの出力確率を表しています。
ディスクリミネーターの目的はジェネレーターに "騙されない" ことです。ディスクリミネーターが実イメージと生成イメージを正しく区別する確率を最大化するには、対応する負の対数尤度関数の和を最小化します。
ディスクリミネーターの損失関数は次の式で表されます。
ここで、 は実イメージに対するディスクリミネーターの出力確率を表しています。
ジェネレーターとディスクリミネーターがそれぞれの目標をどれだけ達成するかを 0 から 1 のスケールで測定するには、スコアの概念を使用できます。
ジェネレーターのスコアは、生成イメージに対するディスクリミネーターの出力に対応する確率の平均です。
ディスクリミネーターのスコアは、実イメージと生成イメージの両方に対するディスクリミネーターの出力に対応する確率の平均です。
スコアは損失に反比例しますが、実質的には同じ情報を表しています。
function [lossG,lossD] = ganLoss(YReal,YGenerated) % Calculate the loss for the discriminator network. lossD = -mean(log(YReal)) - mean(log(1-YGenerated)); % Calculate the loss for the generator network. lossG = -mean(log(YGenerated)); end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順でデータを前処理します。
入力 cell 配列からイメージ データを抽出して数値配列に連結します。
イメージの範囲が
[-1,1]
となるように再スケーリングします。
function X = preprocessMiniBatch(data) % Concatenate mini-batch X = cat(4,data{:}); % Rescale the images in the range [-1 1]. X = rescale(X,-1,1,InputMin=0,InputMax=255); end
参考文献
The TensorFlow Team.Flowers http://download.tensorflow.org/example_images/flower_photos.tgz
Radford, Alec, Luke Metz, and Soumith Chintala."Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks." Preprint, submitted November 19, 2015. http://arxiv.org/abs/1511.06434.
参考
dlnetwork
| forward
| predict
| dlarray
| dlgradient
| dlfeval
| adamupdate
| minibatchqueue