VAE in MATLAB R2022b
Basically I want to recreate the VAE from this page: https://au.mathworks.com/help/deeplearning/ug/train-a-variational-autoencoder-vae-to-generate-images.html
The issue is that when I implement it (line by line, as my first test run) with a custom dataset (or with the recommended dataset), I get an error that says:
"Sampling layer is only for Train a variational autoencoder...".
This means that the "sampling layer" in the encoder and the "feature input layer" in the decoder do not exist for me to use.
Is there any way I can implement a VAE in the current version of MATLAB? If so, what kind of layers and layer setup should I use?

trainImagesFile = "train-images-idx3-ubyte.gz";
testImagesFile = "t10k-images-idx3-ubyte.gz";
XTrain = processImagesMNIST(trainImagesFile);
XTest = processImagesMNIST(testImagesFile);
numLatentChannels = 16;
imageSize = [28 28 1];
layersE = [
    imageInputLayer(imageSize,Normalization="none")
    convolution2dLayer(3,32,Padding="same",Stride=2)
    reluLayer
    convolution2dLayer(3,64,Padding="same",Stride=2)
    reluLayer
    fullyConnectedLayer(2*numLatentChannels)
    samplingLayer];
projectionSize = [7 7 64];
numInputChannels = imageSize(3);   % number of image channels (1 for grayscale)
layersD = [
    featureInputLayer(numLatentChannels)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(3,64,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,32,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,numInputChannels,Cropping="same")
    sigmoidLayer];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
numEpochs = 30;
miniBatchSize = 128;
learnRate = 1e-3;
dsTrain = arrayDatastore(XTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB", ...
    PartialMiniBatch="discard");
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
numObservationsTrain = size(XTrain,4);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");
epoch = 0;
iteration = 0;
% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        % Read mini-batch of data.
        X = next(mbq);

        % Evaluate loss and gradients.
        [loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);

        % Update learnable parameters.
        [netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, ...
            gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
        [netD,trailingAvgD,trailingAvgSqD] = adamupdate(netD, ...
            gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100*iteration/numIterations;
    end
end
dsTest = arrayDatastore(XTest,IterationDimension=4);
numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
err = mean((XTest-YTest).^2,[1 2 3]);
figure
histogram(err)
xlabel("Error")
ylabel("Frequency")
title("Test Data")
numImages = 64;
ZNew = randn(numLatentChannels,numImages);
ZNew = dlarray(ZNew,"CB");
YNew = predict(netD,ZNew);
YNew = extractdata(YNew);
figure
I = imtile(YNew);
imshow(I)
title("Generated Images")
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
    % Forward through encoder.
    [Z,mu,logSigmaSq] = forward(netE,X);

    % Forward through decoder.
    Y = forward(netD,Z);

    % Calculate loss and gradients.
    loss = elboLoss(Y,X,mu,logSigmaSq);
    [gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = elboLoss(Y,T,mu,logSigmaSq)
    % Reconstruction loss.
    reconstructionLoss = mse(Y,T);

    % KL divergence between the encoder distribution N(mu,sigma^2) and the
    % standard normal prior, using the closed form
    % -0.5*sum(1 + log(sigma^2) - mu^2 - sigma^2).
    KL = -0.5 * sum(1 + logSigmaSq - mu.^2 - exp(logSigmaSq),1);
    KL = mean(KL);

    % Combined loss (negative ELBO).
    loss = reconstructionLoss + KL;
end
function Y = modelPredictions(netE,netD,mbq)
    Y = [];

    % Loop over mini-batches.
    while hasdata(mbq)
        X = next(mbq);

        % Forward through encoder.
        Z = predict(netE,X);

        % Forward through decoder.
        XGenerated = predict(netD,Z);

        % Extract and concatenate predictions.
        Y = cat(4,Y,extractdata(XGenerated));
    end
end
function X = preprocessMiniBatch(dataX)
    % Concatenate observations along the fourth (batch) dimension.
    X = cat(4,dataX{:});
end
function X = processImagesMNIST(filename)
    dataFolder = fullfile(tempdir,'mnist');
    gunzip(filename,dataFolder)
    [~,name,~] = fileparts(filename);
    [fileID,errmsg] = fopen(fullfile(dataFolder,name),'r','b');
    if fileID < 0
        error(errmsg);
    end

    % IDX image files start with the magic number 2051 (2049 is for labels).
    magicNum = fread(fileID,1,'int32',0,'b');
    if magicNum == 2051
        fprintf('\nRead MNIST image data...\n')
    end
    numImages = fread(fileID,1,'int32',0,'b');
    fprintf('Number of images in the dataset: %6d ...\n',numImages);
    numRows = fread(fileID,1,'int32',0,'b');
    numCols = fread(fileID,1,'int32',0,'b');

    % Read the pixel data, reshape to 28-by-28-by-1-by-numImages, and
    % rescale to [0,1]; the rest of the script indexes dimension 4.
    X = fread(fileID,inf,'unsigned char');
    X = reshape(X,numCols,numRows,numImages);
    X = permute(X,[2 1 3]);
    X = X./255;
    X = reshape(X,[28,28,1,size(X,3)]);
    fclose(fileID);
end
6 Comments
Arkadiy Turevskiy
31 Jan 2023
Can you add the code so we can see where exactly the error is happening?
Glacial Claw
31 Jan 2023
Glacial Claw
7 Feb 2023
Yoann Roth
7 Feb 2023
Hi,
The layers samplingLayer and projectAndReshapeLayer are custom layers that were written for the example. They won't automatically be on the path, but MATLAB detects that they exist in an example folder, which I believe is the error you are seeing. If you want to use them, you need to make sure the files that define those custom layers are on the path using addpath, or you can just copy the files into your folder.
Can you try this and see if that helps?
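For completeness, the quickest fix looks like this (the folder path is only a placeholder for wherever you saved the two files):

% Either copy samplingLayer.m and projectAndReshapeLayer.m into your
% working folder, or add the folder that contains them to the path.
% "C:\work\vae-example-layers" is a placeholder path.
addpath("C:\work\vae-example-layers");

And in case it helps to see what the missing layer actually does, here is a minimal sketch of a sampling layer built around the reparameterization trick. It matches the encoder above (the fully connected layer outputs 2*numLatentChannels values: means stacked on log-variances), but it is an illustration, not the exact file shipped with the example:

% samplingLayer.m -- minimal sketch of a reparameterization-trick layer.
classdef samplingLayer < nnet.layer.Layer
    methods
        function layer = samplingLayer(args)
            arguments
                args.Name = "";
            end
            layer.Name = args.Name;
            layer.Description = "Sample from a Gaussian given means and log-variances";
            layer.OutputNames = ["out" "mean" "log-variance"];
        end

        function [Z,mu,logSigmaSq] = predict(~,X)
            % Split the input into means and log-variances.
            numLatentChannels = size(X,1)/2;
            miniBatchSize = size(X,2);
            mu = X(1:numLatentChannels,:);
            logSigmaSq = X(numLatentChannels+1:end,:);

            % Reparameterization trick: Z = mu + sigma.*epsilon with
            % epsilon ~ N(0,1), so gradients can flow through mu and sigma.
            epsilon = randn(numLatentChannels,miniBatchSize,"like",X);
            Z = mu + exp(0.5*logSigmaSq).*epsilon;
        end
    end
end

projectAndReshapeLayer is similar in spirit: a layer with learnable Weights and Bias whose predict function calls fullyconnect and then reshapes the result to the 7-by-7-by-64 projection size as "SSCB" data. That one also needs an initialize method to size its weights, so I would copy it from the example rather than rewrite it.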
Glacial Claw
8 Feb 2023
Yoann Roth
8 Feb 2023
Great! I'll add this as an answer then
Accepted Answer
More Answers (0)