From 5aa597960ddb1884aa0c7399890726c6bdb9d5eb Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Tue, 29 Nov 2016 10:35:54 -0800 Subject: [PATCH] minor changes on demo/gan following lzhao4ever comments --- demo/gan/README.md | 3 ++- demo/gan/data/get_mnist_data.sh | 2 +- demo/gan/gan_conf.py | 2 +- demo/gan/gan_trainer.py | 11 ++++------- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/demo/gan/README.md b/demo/gan/README.md index f347f15e2d4..1ec1afa0ba0 100644 --- a/demo/gan/README.md +++ b/demo/gan/README.md @@ -9,4 +9,5 @@ Then you can run the command below. The flag -d specifies the training data (cif $python gan_trainer.py -d cifar --useGpu 1 -The generated images will be stored in ./cifar_samples/ \ No newline at end of file +The generated images will be stored in ./cifar_samples/ +The corresponding models will be stored in ./cifar_params/ \ No newline at end of file diff --git a/demo/gan/data/get_mnist_data.sh b/demo/gan/data/get_mnist_data.sh index 21fd9badc7f..d21bf706713 100644 --- a/demo/gan/data/get_mnist_data.sh +++ b/demo/gan/data/get_mnist_data.sh @@ -1,5 +1,5 @@ #!/usr/bin/env sh -# This scripts downloads the mnist data and unzips it. +# This script downloads the mnist data and unzips it. set -e DIR="$( cd "$(dirname "$0")" ; pwd -P )" rm -rf "$DIR/mnist_data" diff --git a/demo/gan/gan_conf.py b/demo/gan/gan_conf.py index 4f57c80b779..a6943176c22 100644 --- a/demo/gan/gan_conf.py +++ b/demo/gan/gan_conf.py @@ -38,7 +38,7 @@ sample_dim = 2 settings( batch_size=128, learning_rate=1e-4, - learning_method=AdamOptimizer(beta1=0.7) + learning_method=AdamOptimizer(beta1=0.5) ) def discriminator(sample): diff --git a/demo/gan/gan_trainer.py b/demo/gan/gan_trainer.py index 09d223fba8c..572b05f771d 100644 --- a/demo/gan/gan_trainer.py +++ b/demo/gan/gan_trainer.py @@ -87,11 +87,8 @@ def load_mnist_data(imageFile): else: n = 10000 - data = numpy.zeros((n, 28*28), dtype = "float32") - - for i in range(n): - pixels = numpy.fromfile(f, 'ubyte', count=28*28) - data[i, :] = pixels / 255.0 * 2.0 - 1.0 + data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28)) + data = data / 255.0 * 2.0 - 1.0 f.close() return data @@ -235,7 +232,7 @@ def main(): else: data_np = load_uniform_data() - # this create a gradient machine for discriminator + # this creates a gradient machine for discriminator dis_training_machine = api.GradientMachine.createFromConfigProto( dis_conf.model_config) # this create a gradient machine for generator @@ -243,7 +240,7 @@ def main(): gen_conf.model_config) # generator_machine is used to generate data only, which is used for - # training discrinator + # training discriminator logger.info(str(generator_conf.model_config)) generator_machine = api.GradientMachine.createFromConfigProto( generator_conf.model_config) -- GitLab