diff --git a/demo/gan/gan_conf.py b/demo/gan/gan_conf.py index a6943176c2235d2d94b3ecbddea7b7fd50506940..05eee3a9b9ce455eb3a5d47d3165ee7f42f1002e 100644 --- a/demo/gan/gan_conf.py +++ b/demo/gan/gan_conf.py @@ -29,9 +29,9 @@ is_discriminator = mode == "discriminator" print('mode=%s' % mode) # the dim of the noise (z) as the input of the generator network -noise_dim = 10 +noise_dim = 10 # the dim of the hidden layer -hidden_dim = 15 +hidden_dim = 10 # the dim of the generated sample sample_dim = 2 diff --git a/demo/gan/gan_trainer.py b/demo/gan/gan_trainer.py index 3f27f04fc593dc3186ac5fa9ee52464da05b6131..72699952b961cb5bf6ac14dd65eee1aeab5e2a7c 100644 --- a/demo/gan/gan_trainer.py +++ b/demo/gan/gan_trainer.py @@ -91,7 +91,7 @@ def load_mnist_data(imageFile): data = data / 255.0 * 2.0 - 1.0 f.close() - return data + return data.astype('float32') def load_cifar_data(cifar_path): batch_size = 10000