提交 5aa59796 编写于 作者: W wangyang59

minor changes on demo/gan following lzhao4ever comments

上级 531e8354
......@@ -10,3 +10,4 @@ Then you can run the command below. The flag -d specifies the training data (cif
$python gan_trainer.py -d cifar --useGpu 1
The generated images will be stored in ./cifar_samples/
The corresponding models will be stored in ./cifar_params/
\ No newline at end of file
#!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it.
# This script downloads the mnist data and unzips it.
set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
rm -rf "$DIR/mnist_data"
......
......@@ -38,7 +38,7 @@ sample_dim = 2
settings(
batch_size=128,
learning_rate=1e-4,
learning_method=AdamOptimizer(beta1=0.7)
learning_method=AdamOptimizer(beta1=0.5)
)
def discriminator(sample):
......
......@@ -87,11 +87,8 @@ def load_mnist_data(imageFile):
else:
n = 10000
data = numpy.zeros((n, 28*28), dtype = "float32")
for i in range(n):
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
data[i, :] = pixels / 255.0 * 2.0 - 1.0
data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28))
data = data / 255.0 * 2.0 - 1.0
f.close()
return data
......@@ -235,7 +232,7 @@ def main():
else:
data_np = load_uniform_data()
# this create a gradient machine for discriminator
# this creates a gradient machine for discriminator
dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config)
# this create a gradient machine for generator
......@@ -243,7 +240,7 @@ def main():
gen_conf.model_config)
# generator_machine is used to generate data only, which is used for
# training discrinator
# training discriminator
logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册