提交 5aa59796 编写于 作者: W wangyang59

minor changes on demo/gan following lzhao4ever comments

上级 531e8354
...@@ -9,4 +9,5 @@ Then you can run the command below. The flag -d specifies the training data (cif ...@@ -9,4 +9,5 @@ Then you can run the command below. The flag -d specifies the training data (cif
$python gan_trainer.py -d cifar --useGpu 1 $python gan_trainer.py -d cifar --useGpu 1
The generated images will be stored in ./cifar_samples/ The generated images will be stored in ./cifar_samples/
\ No newline at end of file The corresponding models will be stored in ./cifar_params/
\ No newline at end of file
#!/usr/bin/env sh #!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it. # This script downloads the mnist data and unzips it.
set -e set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )" DIR="$( cd "$(dirname "$0")" ; pwd -P )"
rm -rf "$DIR/mnist_data" rm -rf "$DIR/mnist_data"
......
...@@ -38,7 +38,7 @@ sample_dim = 2 ...@@ -38,7 +38,7 @@ sample_dim = 2
settings( settings(
batch_size=128, batch_size=128,
learning_rate=1e-4, learning_rate=1e-4,
learning_method=AdamOptimizer(beta1=0.7) learning_method=AdamOptimizer(beta1=0.5)
) )
def discriminator(sample): def discriminator(sample):
......
...@@ -87,11 +87,8 @@ def load_mnist_data(imageFile): ...@@ -87,11 +87,8 @@ def load_mnist_data(imageFile):
else: else:
n = 10000 n = 10000
data = numpy.zeros((n, 28*28), dtype = "float32") data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28))
data = data / 255.0 * 2.0 - 1.0
for i in range(n):
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
data[i, :] = pixels / 255.0 * 2.0 - 1.0
f.close() f.close()
return data return data
...@@ -235,7 +232,7 @@ def main(): ...@@ -235,7 +232,7 @@ def main():
else: else:
data_np = load_uniform_data() data_np = load_uniform_data()
# this create a gradient machine for discriminator # this creates a gradient machine for discriminator
dis_training_machine = api.GradientMachine.createFromConfigProto( dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config) dis_conf.model_config)
# this create a gradient machine for generator # this create a gradient machine for generator
...@@ -243,7 +240,7 @@ def main(): ...@@ -243,7 +240,7 @@ def main():
gen_conf.model_config) gen_conf.model_config)
# generator_machine is used to generate data only, which is used for # generator_machine is used to generate data only, which is used for
# training discrinator # training discriminator
logger.info(str(generator_conf.model_config)) logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto( generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config) generator_conf.model_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册