
How to find the optimal network architecture for Generative Adversarial Network (GAN) with time sequence input

10 views (last 30 days)
I am currently designing the architecture of a vanilla GAN. The input to the network is a time sequence consisting of 19 features and 2001 time steps. I use a custom mini-batch function for time sequences, together with the training loop from the 'Train Generative Adversarial Network (GAN)' example from MathWorks.
Currently the scores for both the generator and the discriminator are around 0.5. I would expect the generator's score to start around 0 and the discriminator's slightly above 0.5. I suspect the problem is in the discriminator network architecture. My question is: how do I find out which components the discriminator needs?
The code is shown below:
numLatentInputs = 10; %Input noise used by the generator
numGRU = 500; %Number of neurons in one hidden layer
numFeatures = 19; %Same size as training data
numResponses = 1; %Size of the probability of real or fake
Dropout = 0.5;
%% Create generator
layersGenerator = [ ...
sequenceInputLayer(numLatentInputs,"Name","InputSequence_G")
gruLayer(numGRU,"Name","GRU1_G","OutputMode","sequence")
gruLayer(numGRU,"Name","GRU2_G","OutputMode","sequence")
gruLayer(numGRU,"Name","GRU3_G","OutputMode","sequence")
dropoutLayer(DropOut,"Name","dropout")
fullyConnectedLayer(numFeatures,"Name","FullyConnected")];
netG = dlnetwork(layersGenerator); %Convert copy to dl network
%% Create discriminator
layersDiscriminator = [
sequenceInputLayer(numFeatures,"Name","InputSequence_D")
gruLayer(numGRU,"Name","GRU1_D","OutputMode","sequence")
gruLayer(numGRU,"Name","GRU2_D","OutputMode","sequence")
gruLayer(numGRU,"Name","GRU3_D","OutputMode","sequence")
dropoutLayer(DropOut,"Name","DropoutLayer")
fullyConnectedLayer(numResponses,"Name","FullyConnected")
sigmoidLayer("Name","sigmoid")];
netD = dlnetwork(layersDiscriminator); %Convert copy to dl network
I also get the following error. The real and generated training data are divided into mini-batches of size 15. For validation I create a separate data set with the same number of features and time steps, but with 25 observations. The error occurs when I pass the validation data set to the 'predict' function. I don't understand why the network would expect a specific mini-batch size at all.
Error using dlnetwork/predict (line 664)
Layer 'sequence': Invalid input data. Incorrect network state. The network expects mini-batches size of 15, but was passed a
mini-batch of size 25.
Please let me know if you have any ideas. Thank you in advance!

Answers (1)

Divit on 29 January 2024
Hi Chris,
Finding the optimal network architecture for a Generative Adversarial Network (GAN), especially with time sequence input, is an iterative and experimental process. There is no one-size-fits-all solution, and the architecture often depends on the specific characteristics of the data and the task at hand. Here are some general guidelines and steps you can take to refine your GAN architecture:
Refining the Discriminator Architecture
  • Start Simple: Begin with a simpler architecture and gradually increase complexity. For time series data, you might start with one or two GRU layers before adding more.
  • Balance Capacity: Ensure that the discriminator and generator are well-balanced in terms of capacity. If the discriminator is too powerful, it will easily distinguish real from fake data, and the generator won't learn effectively.
  • Regularization: Use dropout and other regularization techniques to prevent overfitting, but don't overdo it as it might impair the learning of the discriminator.
  • Batch Normalization: Consider adding batch normalization to stabilize learning and help with the gradient flow.
  • Adjust Learning Rates: Sometimes, using different learning rates for the generator and discriminator can help balance the training process.
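As a concrete illustration of the "start simple" and "adjust learning rates" points, a pared-down discriminator might look like the sketch below. The layer names, unit counts, and learning-rate values are placeholders for you to tune, not recommendations:

```matlab
% Simpler discriminator sketch: a single GRU layer with "last" output mode,
% so the network emits one real/fake score per sequence rather than one per
% time step. Values here are illustrative starting points only.
numFeatures = 19;
numGRU = 100;              % far fewer units than 500 to start with
layersDiscriminatorSimple = [
    sequenceInputLayer(numFeatures,"Name","InputSequence_D")
    gruLayer(numGRU,"Name","GRU1_D","OutputMode","last")
    dropoutLayer(0.3,"Name","DropoutLayer")
    fullyConnectedLayer(1,"Name","FullyConnected")
    sigmoidLayer("Name","sigmoid")];
netD = dlnetwork(layersDiscriminatorSimple);

% Separate learning rates for the two networks, e.g. a slower discriminator
% so the generator can keep up during training:
learnRateG = 2e-4;
learnRateD = 1e-4;
```

If this simpler discriminator trains stably, you can add depth back one layer at a time and watch how the scores respond.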
Dealing with the Error
The error you're encountering suggests that the network expects a specific mini-batch size during prediction based on how it was trained. In MATLAB, when using sequence networks like LSTM or GRU, the mini-batch size is typically flexible, but if you've configured your network to be stateful, it might expect a consistent mini-batch size.
To resolve this error, consider the following:
  • Stateless Networks: Ensure your GRU layers are stateless. In MATLAB, stateless RNN layers do not maintain state between mini-batches, allowing for variable mini-batch sizes during prediction.
  • Reset States: If your network is stateful, you may need to reset the states of the network before making predictions with a new mini-batch size.
  • Consistent Mini-Batch Size: Try to use the same mini-batch size for training, validation, and testing. If this is not possible, you might need to retrain the network to be stateless or handle variable mini-batch sizes.
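For example, resetting the stored network state before prediction removes the dependence on the training mini-batch size. This is a minimal sketch, where "dlValidation" stands in for your 25-observation validation dlarray:

```matlab
% If the network state was last updated with mini-batches of size 15, the
% recurrent layers' hidden states are pinned to that size. Clearing the
% state lets predict accept a different mini-batch size:
netD = resetState(netD);
YPred = predict(netD, dlValidation);  % dlValidation: sequence dlarray with batch size 25
```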
Note that there is a typo in your code: the variable is defined as "Dropout" but passed as "DropOut" to "dropoutLayer", which is undefined. Here's the corrected line:
dropoutLayer(Dropout,"Name","dropout")
To learn more about GANs you can refer to the following documentation link:

Release: R2021b
