Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

変分自己符号化器 (VAE) の学習によるイメージ生成

この例では、MATLAB で変分自己符号化器 (VAE) を作成して数字のイメージを生成する方法を説明します。VAE は、手書きの数字を MNIST データセットのスタイルで生成します。

VAE は入力の再構成に符号化と復号化の処理を行わないという点で、通常の自己符号化器と異なります。代わりに、潜在空間上に確率分布を適用してその分布を学習することで、復号化器からの出力の分布と観測データの分布を一致させます。その後、この分布からサンプリングして新しいデータを生成します。

この例では、VAE ネットワークを構築し、MNIST データセットで学習させ、データセット内のイメージによく似た新しいイメージを生成します。

データの読み込み

MNIST ファイルを http://yann.lecun.com/exdb/mnist/ からダウンロードし、MNIST データセットをワークスペースに読み込みます [1]。この例に添付されている補助関数 processImagesMNIST と補助関数 processLabelsMNIST を呼び出して、ファイルから MATLAB 配列にデータを読み込みます。

VAE は、再構築された数字を入力と比較し、カテゴリカル ラベルと比較するのではないため、MNIST データセットの学習ラベルを使用する必要はありません。

trainImagesFile = 'train-images-idx3-ubyte.gz';
testImagesFile = 't10k-images-idx3-ubyte.gz';
testLabelsFile = 't10k-labels-idx1-ubyte.gz';

XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data...
Number of images in the dataset:  60000 ...
numTrainImages = size(XTrain,4);
XTest = processImagesMNIST(testImagesFile);
Read MNIST image data...
Number of images in the dataset:  10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data...
Number of labels in the dataset:  10000 ...

ネットワークの構築

自己符号化器には、符号化器と復号化器の 2 つの部分があります。符号化器は、イメージ入力を受け取り、圧縮された表現を出力 (符号化) します。これはサイズが latentDim (この例では 20) のベクトルです。復号化器は圧縮表現を受け取り、復号化して元のイメージを再作成します。

計算を数値的に安定させるには、ネットワークに分散の対数から学習させて、取りうる値の範囲を [0,1] から [-inf, 0] に増やします。サイズ latent_dim の 2 つのベクトルを定義します。1 つは平均 μ、もう 1 つは分散の対数 log(σ2) です。次にこれら 2 つのベクトルを使用して、サンプリングの元になる分布を作成します。

2 次元畳み込みと、その後に続く全結合層を使用して、28 x 28 x 1 の MNIST イメージを潜在空間における符号化にダウンサンプリングします。その後、転置 2 次元畳み込みを使用して 1 x 1 x 20 の符号化を 28 x 28 x 1 のイメージにスケールアップして戻します。

latentDim = 20;
imageSize = [28 28 1];

encoderLG = layerGraph([
    imageInputLayer(imageSize,'Name','input_encoder','Normalization','none')
    convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder')
    ]);

decoderLG = layerGraph([
    imageInputLayer([1 1 latentDim],'Name','i','Normalization','none')
    transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
    ]);

カスタム学習ループを使用して両方のネットワークに学習させて自動微分を有効にするために、層グラフを dlnetwork オブジェクトに変換します。

encoderNet = dlnetwork(encoderLG);
decoderNet = dlnetwork(decoderLG);

モデル勾配関数の定義

補助関数 modelGradients は、符号化器と復号化器の dlnetwork オブジェクトと入力データのミニバッチ X を受け取り、ネットワーク内の学習可能なパラメーターについての損失の勾配を返します。この補助関数の定義は、この例の終わりで行います。

関数はこの処理をサンプリングと損失の 2 つの手順で実行します。サンプリングの手順では、平均と分散のベクトルをサンプリングして、復号化器ネットワークに渡される最終的な符号化を作成します。ただし、無作為抽出操作の逆伝播は不可能なので、"再パラメーター化のトリック" を使用しなければなりません。このトリックでは、無作為抽出操作を補助変数 ε に移動します。その後、これを平均 μi でシフトし、標準偏差 σi でスケーリングします。ここでは、N(μi,σi2) からのサンプリングは μi+εσi からのサンプリングと同じであると考えます。ここで、εN(0,1) です。次の図に、この考え方を視覚的に示します。

損失の手順では、サンプリング ステップで生成された符号化を復号化器ネットワークに渡し、損失を判定し、これを使って勾配を計算します。VAE の損失は、変分下限 (ELBO) 損失とも呼ばれ、2 つの個別の損失項の和として定義されます。

ELBO loss=reconstruction loss+KL loss.

"再構成損失" は、復号化器の出力が元の入力にどれだけ近いかを、平均二乗誤差 (MSE) を使用して測定します。

reconstruction loss=MSE(decoder output,original image).

"KL 損失"、つまりカルバック・ライブラー ダイバージェンスは、2 つの確率分布の差を測定します。この場合、KL 損失の最小化によって、学習した平均と分散がターゲット (正規) 分布にできる限り近づくようにします。サイズ n の潜在次元について、KL 損失は次のように取得されます。

KL loss=-0.5i=1n(1+log(σi2)-μi2-σi2).

KL 損失項を含めることで、再構成の損失により学習したクラスターを潜在空間の中心の周りに密集させて、サンプリング元となる連続空間を構成するという実用的な効果が得られます。

学習オプションの指定

GPU が利用できる場合、GPU で学習を行います (Parallel Computing Toolbox™ が必要です)。

executionEnvironment = "auto";

ネットワークの学習オプションを設定します。Adam オプティマイザーを使用する場合は、各ネットワークについて、移動平均勾配と移動平均 2 乗勾配の減衰率を空の配列に初期化する必要があります

numEpochs = 50;
miniBatchSize = 512;
lr = 1e-3;
numIterations = floor(numTrainImages/miniBatchSize);
iteration = 0;

avgGradientsEncoder = [];
avgGradientsSquaredEncoder = [];
avgGradientsDecoder = [];
avgGradientsSquaredDecoder = [];

モデルの学習

カスタム学習ループを使用してモデルに学習させます。

エポック内のそれぞれの反復で、次を行います。

  • 学習セットから次のミニバッチを取得。

  • ミニバッチを dlarray オブジェクトに変換。次元ラベルに 'SSCB' (spatial、spatial、channel、batch) を指定していることを確認してください。

  • GPU で学習する場合、dlarraygpuArray オブジェクトに変換。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配を評価。

  • 関数 adamupdate を使用して、両方のネットワークの学習可能項目と平均勾配を更新。

各エポックの最後に、自己符号化器にテスト セットのイメージを渡し、そのエポックの損失と学習時間を表示します。

for epoch = 1:numEpochs
    tic;
    for i = 1:numIterations
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        XBatch = XTrain(:,:,:,idx);
        XBatch = dlarray(single(XBatch), 'SSCB');
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            XBatch = gpuArray(XBatch);           
        end 
            
        [infGrad, genGrad] = dlfeval(...
            @modelGradients, encoderNet, decoderNet, XBatch);
        
        [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ...
            adamupdate(decoderNet.Learnables, ...
                genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);
        [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ...
            adamupdate(encoderNet.Learnables, ...
                infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);
    end
    elapsedTime = toc;
    
    [z, zMean, zLogvar] = sampling(encoderNet, XTest);
    xPred = sigmoid(forward(decoderNet, z));
    elbo = ELBOloss(XTest, xPred, zMean, zLogvar);
    disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+...
        ". Time taken for epoch = "+ elapsedTime + "s")    
end
Epoch : 1 Test ELBO loss = 28.0145. Time taken for epoch = 28.0573s
Epoch : 2 Test ELBO loss = 24.8995. Time taken for epoch = 8.797s
Epoch : 3 Test ELBO loss = 23.2756. Time taken for epoch = 8.8824s
Epoch : 4 Test ELBO loss = 21.151. Time taken for epoch = 8.5979s
Epoch : 5 Test ELBO loss = 20.5335. Time taken for epoch = 8.8472s
Epoch : 6 Test ELBO loss = 20.232. Time taken for epoch = 8.5068s
Epoch : 7 Test ELBO loss = 19.9988. Time taken for epoch = 8.4356s
Epoch : 8 Test ELBO loss = 19.8955. Time taken for epoch = 8.4015s
Epoch : 9 Test ELBO loss = 19.7991. Time taken for epoch = 8.8089s
Epoch : 10 Test ELBO loss = 19.6773. Time taken for epoch = 8.4269s
Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s
Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s
Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s
Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s
Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s
Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s
Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s
Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s
Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s
Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s
Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s
Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s
Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s
Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s
Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s
Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s
Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s
Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s
Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s
Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s
Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s
Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s
Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s
Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s
Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s
Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s
Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s
Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s
Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s
Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s
Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s
Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s
Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s
Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s
Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s
Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s
Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s
Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s
Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s
Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s

結果の可視化

結果を可視化して解釈するには、補助の可視化関数使用します。これらの補助関数の定義は、この例の最後にあります。

関数 VisualizeReconstruction は、各クラスからランダムに選択された数字と、自己符号化器を通過した後の再構築を一緒に表示します。

関数 VisualizeLatentSpace は、テスト イメージを符号化器ネットワークに渡した後に生成された (それぞれ次元が 20 の) 平均と分散の符号化を受け取り、各イメージの符号化を含む行列に対して主成分分析 (PCA) を実行します。その後、最初の 2 つの主成分によって特徴付けられた 2 つの次元における平均と分散で定義された潜在空間を可視化することができます。

関数 Generate は、正規分布からサンプリングされた新しい符号化を初期化し、これらの符号化が復号化器ネットワークを通過するときに生成されたイメージを出力します。

visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace(XTest, YTest, encoderNet)

generate(decoderNet, latentDim)

次のステップ

生成タスクの実行に使用できるモデルは数多くあり、変分自己符号化器はその 1 つに過ぎません。変分自己符号化器は、イメージが小さく、特徴が明確に定義されているデータセット (MNIST など) に対して適切に機能します。イメージが大きく、より複雑なデータセットの場合は、敵対的生成ネットワーク (GAN) の方がより効果的に機能し、ノイズの少ないイメージを生成する傾向にあります。GAN を実装して 64 x 64 の RGB イメージを生成する方法を示す例については、敵対的生成ネットワーク (GAN) の学習を参照してください。

参考文献

  1. LeCun, Y., C. Cortes, and C. J. C. Burges. "The MNIST Database of Handwritten Digits." http://yann.lecun.com/exdb/mnist/.

補助関数

モデル勾配関数

関数 modelGradients は符号化器と復号化器の dlnetwork オブジェクトと、入力データ X のミニバッチを受け取り、学習可能なネットワーク パラメーターについての損失の勾配を返します。この関数は 3 つの処理を行います。

  1. 符号化器ネットワークに渡されたイメージのミニバッチに対して関数 sampling を呼び出し、符号化を取得する。

  2. 復号化器ネットワークに符号化を渡して関数 ELBOloss を呼び出し、損失を取得する。

  3. 関数 dlgradient を呼び出し、両方のネットワークに対して学習可能なパラメーターについての損失の勾配を計算する。

function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end

サンプリングと損失関数

関数 sampling は、入力イメージから符号化を取得します。最初に、イメージのミニバッチを符号化器ネットワークに渡して、サイズ (2*latentDim)*miniBatchSize の出力を平均の行列と分散の行列に分割します。それぞれのサイズは latentDim*batchSize です。次に、これらの行列を使用して再パラメーター化トリックの実装と符号化の計算を行います。最後に、符号化を SCCB フォーマットの dlarray オブジェクトに変換します。

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end

関数 ELBOloss は、関数 sampling から返された平均と分散の符号化を受け取り、ELBO 損失の計算に使用します。

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end

可視化関数

関数 VisualizeReconstruction は、MNIST データセットの各数字に対して 2 つのイメージをランダムに選択して VAE に渡し、元の入力と並べて再構成をプロットします。dlarray オブジェクト内に含まれる情報をプロットするには、最初に関数 extractdata と関数 gather を使用して抽出する必要があることに注意してください。

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));
        
        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end

関数 VisualizeLatentSpace は、符号化ネットワークの出力を構成する平均と分散の行列によって定義される潜在空間を可視化し、各数字の潜在空間表現で構成されるクラスターを配置します。

この関数は、最初に dlarray オブジェクトから平均と分散の行列を抽出します。チャネル/バッチ次元 (C と B) をもつ行列の転置は不可能であるため、関数は行列を転置する前に stripdims を呼び出します。次に、両方の行列に対して主成分分析 (PCA) を行います。潜在空間を 2 次元で可視化するために、関数は最初の 2 つの主成分を保持し、これらを互いに対してプロットします。最後に、数字のクラスを色付けしてクラスターを観察できるようにします。

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end

関数 generate は、VAE の生成能力をテストします。ランダムに生成された 25 個の符号化を含む dlarray オブジェクトを初期化し、復号化器ネットワークに渡して出力をプロットします。

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end

参考

| | | | | |

関連するトピック