提交 531e8354 编写于 作者: W wangyang59

changes to demo/gan following lzhao4ever comments

上级 9a02bd41
output/
uniform_params/
cifar_params/
mnist_params/
*.png
.pydevproject
.project
......
......@@ -24,6 +24,9 @@ is_discriminator_training = mode == "discriminator_training"
is_generator = mode == "generator"
is_discriminator = mode == "discriminator"
# The network structure below follows the ref https://arxiv.org/abs/1406.2661
# Here we used two hidden layers and batch_norm
print('mode=%s' % mode)
# the dim of the noise (z) as the input of the generator network
noise_dim = 10
......
......@@ -90,10 +90,8 @@ def load_mnist_data(imageFile):
data = numpy.zeros((n, 28*28), dtype = "float32")
for i in range(n):
pixels = []
for j in range(28 * 28):
pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0)
data[i, :] = pixels
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
data[i, :] = pixels / 255.0 * 2.0 - 1.0
f.close()
return data
......@@ -129,7 +127,7 @@ def merge(images, size):
((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
return img.astype('uint8')
def saveImages(images, path):
def save_images(images, path):
merged_img = merge(images, [8, 8])
if merged_img.shape[2] == 1:
im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB')
......@@ -207,9 +205,15 @@ def main():
useGpu = args.useGpu
assert dataSource in ["mnist", "cifar", "uniform"]
assert useGpu in ["0", "1"]
if not os.path.exists("./%s_samples/" % dataSource):
os.makedirs("./%s_samples/" % dataSource)
if not os.path.exists("./%s_params/" % dataSource):
os.makedirs("./%s_params/" % dataSource)
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
'--gpu_id=' + args.gpuId)
'--gpu_id=' + args.gpuId, '--save_dir=' + "./%s_params/" % dataSource)
if dataSource == "uniform":
conf = "gan_conf.py"
......@@ -231,9 +235,6 @@ def main():
else:
data_np = load_uniform_data()
if not os.path.exists("./%s_samples/" % dataSource):
os.makedirs("./%s_samples/" % dataSource)
# this create a gradient machine for discriminator
dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config)
......@@ -321,7 +322,7 @@ def main():
if dataSource == "uniform":
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
else:
saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
save_images(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
dis_trainer.finishTrain()
gen_trainer.finishTrain()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册