提交 531e8354 编写于 作者: W wangyang59

changes to demo/gan following lzhao4ever comments

上级 9a02bd41
output/ output/
uniform_params/
cifar_params/
mnist_params/
*.png *.png
.pydevproject .pydevproject
.project .project
......
...@@ -24,6 +24,9 @@ is_discriminator_training = mode == "discriminator_training" ...@@ -24,6 +24,9 @@ is_discriminator_training = mode == "discriminator_training"
is_generator = mode == "generator" is_generator = mode == "generator"
is_discriminator = mode == "discriminator" is_discriminator = mode == "discriminator"
# The network structure below follows the ref https://arxiv.org/abs/1406.2661
# Here we used two hidden layers and batch_norm
print('mode=%s' % mode) print('mode=%s' % mode)
# the dim of the noise (z) as the input of the generator network # the dim of the noise (z) as the input of the generator network
noise_dim = 10 noise_dim = 10
......
...@@ -90,10 +90,8 @@ def load_mnist_data(imageFile): ...@@ -90,10 +90,8 @@ def load_mnist_data(imageFile):
data = numpy.zeros((n, 28*28), dtype = "float32") data = numpy.zeros((n, 28*28), dtype = "float32")
for i in range(n): for i in range(n):
pixels = [] pixels = numpy.fromfile(f, 'ubyte', count=28*28)
for j in range(28 * 28): data[i, :] = pixels / 255.0 * 2.0 - 1.0
pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0)
data[i, :] = pixels
f.close() f.close()
return data return data
...@@ -129,7 +127,7 @@ def merge(images, size): ...@@ -129,7 +127,7 @@ def merge(images, size):
((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0) ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
return img.astype('uint8') return img.astype('uint8')
def saveImages(images, path): def save_images(images, path):
merged_img = merge(images, [8, 8]) merged_img = merge(images, [8, 8])
if merged_img.shape[2] == 1: if merged_img.shape[2] == 1:
im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB') im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB')
...@@ -208,8 +206,14 @@ def main(): ...@@ -208,8 +206,14 @@ def main():
assert dataSource in ["mnist", "cifar", "uniform"] assert dataSource in ["mnist", "cifar", "uniform"]
assert useGpu in ["0", "1"] assert useGpu in ["0", "1"]
if not os.path.exists("./%s_samples/" % dataSource):
os.makedirs("./%s_samples/" % dataSource)
if not os.path.exists("./%s_params/" % dataSource):
os.makedirs("./%s_params/" % dataSource)
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100', api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
'--gpu_id=' + args.gpuId) '--gpu_id=' + args.gpuId, '--save_dir=' + "./%s_params/" % dataSource)
if dataSource == "uniform": if dataSource == "uniform":
conf = "gan_conf.py" conf = "gan_conf.py"
...@@ -231,9 +235,6 @@ def main(): ...@@ -231,9 +235,6 @@ def main():
else: else:
data_np = load_uniform_data() data_np = load_uniform_data()
if not os.path.exists("./%s_samples/" % dataSource):
os.makedirs("./%s_samples/" % dataSource)
# this create a gradient machine for discriminator # this create a gradient machine for discriminator
dis_training_machine = api.GradientMachine.createFromConfigProto( dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config) dis_conf.model_config)
...@@ -321,7 +322,7 @@ def main(): ...@@ -321,7 +322,7 @@ def main():
if dataSource == "uniform": if dataSource == "uniform":
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
else: else:
saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) save_images(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
dis_trainer.finishTrain() dis_trainer.finishTrain()
gen_trainer.finishTrain() gen_trainer.finishTrain()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册