diff --git a/demo/gan/.gitignore b/demo/gan/.gitignore
index e9a33070cf40d81ec4391d63660973da040097c9..93a6f5080a16a601cffb0bff51af9aef3ba3bae7 100644
--- a/demo/gan/.gitignore
+++ b/demo/gan/.gitignore
@@ -1,4 +1,7 @@
 output/
+uniform_params/
+cifar_params/
+mnist_params/
 *.png
 .pydevproject
 .project
diff --git a/demo/gan/gan_conf.py b/demo/gan/gan_conf.py
index 6bd68727ba5321c0ac6fe896dc1f48e138632d69..4f57c80b77921c6d2df7bb50a90024f9356ad623 100644
--- a/demo/gan/gan_conf.py
+++ b/demo/gan/gan_conf.py
@@ -24,6 +24,9 @@ is_discriminator_training = mode == "discriminator_training"
 is_generator = mode == "generator"
 is_discriminator = mode == "discriminator"
 
+# The network structure below follows the ref https://arxiv.org/abs/1406.2661
+# Here we used two hidden layers and batch_norm
+
 print('mode=%s' % mode)
 # the dim of the noise (z) as the input of the generator network
 noise_dim = 10
diff --git a/demo/gan/gan_trainer.py b/demo/gan/gan_trainer.py
index 063a98acdcf1ae9f31d5b1c8202dbc66fcf09540..09d223fba8c70beb5309ed6c7480d61f904e8bcb 100644
--- a/demo/gan/gan_trainer.py
+++ b/demo/gan/gan_trainer.py
@@ -90,10 +90,8 @@ def load_mnist_data(imageFile):
     data = numpy.zeros((n, 28*28), dtype = "float32")
 
     for i in range(n):
-        pixels = []
-        for j in range(28 * 28):
-            pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0)
-        data[i, :] = pixels
+        pixels = numpy.fromfile(f, 'ubyte', count=28*28)
+        data[i, :] = pixels / 255.0 * 2.0 - 1.0
 
     f.close()
     return data
@@ -129,7 +127,7 @@ def merge(images, size):
             ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
     return img.astype('uint8')
 
-def saveImages(images, path):
+def save_images(images, path):
     merged_img = merge(images, [8, 8])
     if merged_img.shape[2] == 1:
         im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB')
@@ -207,9 +205,15 @@ def main():
     useGpu = args.useGpu
     assert dataSource in ["mnist", "cifar", "uniform"]
     assert useGpu in ["0", "1"]
-
+
+    if not os.path.exists("./%s_samples/" % dataSource):
+        os.makedirs("./%s_samples/" % dataSource)
+
+    if not os.path.exists("./%s_params/" % dataSource):
+        os.makedirs("./%s_params/" % dataSource)
+
     api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
-                   '--gpu_id=' + args.gpuId)
+                   '--gpu_id=' + args.gpuId, '--save_dir=' + "./%s_params/" % dataSource)
 
     if dataSource == "uniform":
         conf = "gan_conf.py"
@@ -231,9 +235,6 @@
     else:
         data_np = load_uniform_data()
 
-    if not os.path.exists("./%s_samples/" % dataSource):
-        os.makedirs("./%s_samples/" % dataSource)
-
     # this create a gradient machine for discriminator
     dis_training_machine = api.GradientMachine.createFromConfigProto(
         dis_conf.model_config)
@@ -321,7 +322,7 @@
         if dataSource == "uniform":
             plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
         else:
-            saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
+            save_images(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
 
     dis_trainer.finishTrain()
     gen_trainer.finishTrain()
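For reference, a minimal standalone sketch of the vectorized MNIST read that the gan_trainer.py hunk switches to: one numpy.fromfile call per image instead of a per-byte Python loop, with pixels rescaled from [0, 255] to [-1, 1]. The function name load_mnist_images, the sample count argument, and the example path are illustrative only; the 16-byte idx3-ubyte header skip is assumed from the surrounding load_mnist_data() code, which is not shown in this hunk.

import numpy

def load_mnist_images(image_file, n):
    # Hypothetical helper mirroring the patched loop in load_mnist_data().
    with open(image_file, "rb") as f:
        f.read(16)  # skip the idx3-ubyte header (magic number, image count, dims)
        data = numpy.zeros((n, 28 * 28), dtype="float32")
        for i in range(n):
            # read one 784-byte image and rescale pixels from [0, 255] to [-1, 1]
            pixels = numpy.fromfile(f, 'ubyte', count=28 * 28)
            data[i, :] = pixels / 255.0 * 2.0 - 1.0
    return data

# Example (path and count are placeholders):
# train_data = load_mnist_images("./data/mnist_data/train-images-idx3-ubyte", 60000)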