From d8aada072b3dad38f8bb4cd9d4878376cb2d3e2b Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Mon, 14 Nov 2016 16:34:23 -0800 Subject: [PATCH] added cifar data into dema/gan --- demo/gan/.gitignore | 4 +- demo/gan/data/download_cifar.sh | 18 +++++++ demo/gan/gan_conf_image.py | 11 ++-- demo/gan/gan_trainer_image.py | 89 ++++++++++++++++++++------------- 4 files changed, 82 insertions(+), 40 deletions(-) create mode 100755 demo/gan/data/download_cifar.sh diff --git a/demo/gan/.gitignore b/demo/gan/.gitignore index 150fa0ab54..f03677e753 100644 --- a/demo/gan/.gitignore +++ b/demo/gan/.gitignore @@ -2,5 +2,7 @@ output/ *.png .pydevproject .project -train.log +*.log +*.pyc data/raw_data/ +data/cifar-10-batches-py/ diff --git a/demo/gan/data/download_cifar.sh b/demo/gan/data/download_cifar.sh new file mode 100755 index 0000000000..ea3be594cd --- /dev/null +++ b/demo/gan/data/download_cifar.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz +tar zxf cifar-10-python.tar.gz +rm cifar-10-python.tar.gz + diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py index e811bb96e8..0c3f3a343b 100644 --- a/demo/gan/gan_conf_image.py +++ b/demo/gan/gan_conf_image.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from paddle.trainer_config_helpers import * -from paddle.trainer_config_helpers.activations import LinearActivation -from numpy.distutils.system_info import tmp mode = get_config_arg("mode", str, "generator") +dataSource = get_config_arg("data", str, "mnist") assert mode in set(["generator", "discriminator", "generator_training", @@ -30,8 +29,12 @@ print('mode=%s' % mode) noise_dim = 100 gf_dim = 64 df_dim = 64 -sample_dim = 28 # image dim -c_dim = 1 # image color +if dataSource == "mnist": + sample_dim = 28 # image dim + c_dim = 1 # image color +else: + sample_dim = 32 + c_dim = 3 s2, s4 = int(sample_dim/2), int(sample_dim/4), s8, s16 = int(sample_dim/8), int(sample_dim/16) diff --git a/demo/gan/gan_trainer_image.py b/demo/gan/gan_trainer_image.py index 9c7ddd4796..e8ed218663 100644 --- a/demo/gan/gan_trainer_image.py +++ b/demo/gan/gan_trainer_image.py @@ -16,31 +16,13 @@ import argparse import itertools import random import numpy +import cPickle import sys,os,gc from PIL import Image from paddle.trainer.config_parser import parse_config from paddle.trainer.config_parser import logger import py_paddle.swig_paddle as api -from py_paddle import DataProviderConverter - -import matplotlib.pyplot as plt - - -def plot2DScatter(data, outputfile): - # Generate some test data - x = data[:, 0] - y = data[:, 1] - print "The mean vector is %s" % numpy.mean(data, 0) - print "The std vector is %s" % numpy.std(data, 0) - - heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50) - extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] - - plt.clf() - plt.scatter(x, y) - # plt.show() - plt.savefig(outputfile, bbox_inches='tight') def CHECK_EQ(a, b): assert a == b, "a=%s, b=%s" % (a, b) @@ -94,18 +76,39 @@ def load_mnist_data(imageFile): f.close() return data +def load_cifar_data(cifar_path): + batch_size = 10000 + data = numpy.zeros((5*batch_size, 32*32*3), dtype = "float32") + for i in range(1, 6): + file = cifar_path + "/data_batch_" + str(i) + fo = open(file, 'rb') + dict = cPickle.load(fo) + fo.close() + data[(i - 1)*batch_size:(i*batch_size), :] = dict["data"] + + data = data / 255.0 * 2.0 - 1.0 + return data + def merge(images, size): - h, w = 28, 28 - img = numpy.zeros((h * size[0], w * size[1])) + if images.shape[1] == 28*28: + h, w, c = 28, 28, 1 + else: + h, w, c = 32, 32, 3 + img = numpy.zeros((h * size[0], w * size[1], c)) for idx in xrange(size[0] * size[1]): i = idx % size[1] j = idx // size[1] - img[j*h:j*h+h, i*w:i*w+w] = (images[idx, :].reshape((h, w)) + 1.0) / 2.0 * 255.0 - return img + #img[j*h:j*h+h, i*w:i*w+w, :] = (images[idx, :].reshape((h, w, c), order="F") + 1.0) / 2.0 * 255.0 + img[j*h:j*h+h, i*w:i*w+w, :] = \ + ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0) + return img.astype('uint8') def saveImages(images, path): merged_img = merge(images, [8, 8]) - im = Image.fromarray(merged_img).convert('RGB') + if merged_img.shape[2] == 1: + im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB') + else: + im = Image.fromarray(merged_img, mode="RGB") im.save(path) def get_real_samples(batch_size, data_np): @@ -115,9 +118,9 @@ def get_real_samples(batch_size, data_np): def get_noise(batch_size, noise_dim): return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32') -def get_sample_noise(batch_size): - return numpy.random.normal(size=(batch_size, 28*28), - scale=0.1).astype('float32') +def get_sample_noise(batch_size, sample_dim): + return numpy.random.normal(size=(batch_size, sample_dim), + scale=0.01).astype('float32') def get_fake_samples(generator_machine, batch_size, noise): gen_inputs = api.Arguments.createArguments(1) @@ -177,15 +180,31 @@ def get_layer_size(model_conf, layer_name): def main(): - api.initPaddle('--use_gpu=1', '--dot_period=10', '--log_period=100') - gen_conf = parse_config("gan_conf_image.py", "mode=generator_training") - dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training") - generator_conf = parse_config("gan_conf_image.py", "mode=generator") + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--dataSource", help="mnist or cifar") + parser.add_argument("--useGpu", default="1", + help="1 means use gpu for training") + args = parser.parse_args() + dataSource = args.dataSource + useGpu = args.useGpu + assert dataSource in ["mnist", "cifar"] + assert useGpu in ["0", "1"] + + api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100') + gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource) + dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource) + generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource) batch_size = dis_conf.opt_config.batch_size noise_dim = get_layer_size(gen_conf.model_config, "noise") sample_dim = get_layer_size(dis_conf.model_config, "sample") - data_np = load_mnist_data("./data/raw_data/train-images-idx3-ubyte") + if dataSource == "mnist": + data_np = load_mnist_data("./data/raw_data/train-images-idx3-ubyte") + else: + data_np = load_cifar_data("./data/cifar-10-batches-py/") + + if not os.path.exists("./%s_samples/" % dataSource): + os.makedirs("./%s_samples/" % dataSource) # this create a gradient machine for discriminator dis_training_machine = api.GradientMachine.createFromConfigProto( @@ -224,12 +243,12 @@ def main(): # generator_machine, batch_size, noise_dim, sample_dim) # dis_loss = get_training_loss(dis_training_machine, data_batch_dis) noise = get_noise(batch_size, noise_dim) - sample_noise = get_sample_noise(batch_size) + sample_noise = get_sample_noise(batch_size, sample_dim) data_batch_dis_pos = prepare_discriminator_data_batch_pos( batch_size, data_np, sample_noise) dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos) - sample_noise = get_sample_noise(batch_size) + sample_noise = get_sample_noise(batch_size, sample_dim) data_batch_dis_neg = prepare_discriminator_data_batch_neg( generator_machine, batch_size, noise, sample_noise) dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg) @@ -271,7 +290,7 @@ def main(): fake_samples = get_fake_samples(generator_machine, batch_size, noise) - saveImages(fake_samples, "train_pass%s.png" % train_pass) + saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) dis_trainer.finishTrain() gen_trainer.finishTrain() -- GitLab