From 575b888130337bde041bfa986b9451fee62c2210 Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Thu, 17 Nov 2016 15:37:51 -0800 Subject: [PATCH] add readme and comments in demo/gan --- demo/gan/README.md | 12 ++ demo/gan/gan_conf.py | 6 +- demo/gan/gan_conf_image.py | 20 ++- demo/gan/gan_trainer.py | 235 +++++++++++++++++-------- demo/gan/gan_trainer_image.py | 313 ---------------------------------- 5 files changed, 200 insertions(+), 386 deletions(-) create mode 100644 demo/gan/README.md delete mode 100644 demo/gan/gan_trainer_image.py diff --git a/demo/gan/README.md b/demo/gan/README.md new file mode 100644 index 0000000000..51faba1919 --- /dev/null +++ b/demo/gan/README.md @@ -0,0 +1,12 @@ +# Generative Adversarial Networks (GAN) + +This demo implements GAN training described in the original GAN paper (https://arxiv.org/abs/1406.2661) and DCGAN (https://arxiv.org/abs/1511.06434). + +The general training procedures are implemented in gan_trainer.py. The neural network configurations are specified in gan_conf.py (for synthetic data) and gan_conf_image.py (for image data). + +In order to run the model, first download the corresponding data by running the shell script in ./data. +Then you can run the command below. The flag -d specifies the training data (cifar, mnist or uniform) and flag --useGpu specifies whether to use gpu for training (0 is cpu, 1 is gpu). + +$python gan_trainer_image.py -d cifar --useGpu 1 + +The generated images will be stored in ./cifar_samples/ \ No newline at end of file diff --git a/demo/gan/gan_conf.py b/demo/gan/gan_conf.py index e9e3d2f07d..6bd68727ba 100644 --- a/demo/gan/gan_conf.py +++ b/demo/gan/gan_conf.py @@ -25,8 +25,11 @@ is_generator = mode == "generator" is_discriminator = mode == "discriminator" print('mode=%s' % mode) -noise_dim = 10 +# the dim of the noise (z) as the input of the generator network +noise_dim = 10 +# the dim of the hidden layer hidden_dim = 15 +# the dim of the generated sample sample_dim = 2 settings( @@ -123,7 +126,6 @@ if is_generator_training or is_discriminator_training: classification_error_evaluator(input=prob, label=label, name=mode+'_error') outputs(cost) - if is_generator: noise = data_layer(name="noise", size=noise_dim) outputs(generator(noise)) diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py index 5d42f3238c..dc5910e9f0 100644 --- a/demo/gan/gan_conf_image.py +++ b/demo/gan/gan_conf_image.py @@ -25,8 +25,14 @@ is_discriminator_training = mode == "discriminator_training" is_generator = mode == "generator" is_discriminator = mode == "discriminator" +# The network structure below follows the dcgan paper +# (https://arxiv.org/abs/1511.06434) + print('mode=%s' % mode) +# the dim of the noise (z) as the input of the generator network noise_dim = 100 +# the number of filters in the layer in generator/discriminator that is +# closet to the image gf_dim = 64 df_dim = 64 if dataSource == "mnist": @@ -47,6 +53,19 @@ settings( def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name, param_attr, bias_attr, param_attr_bn, bn, trans=False, act=ReluActivation()): + + """ + conv_bn is a utility function that constructs a convolution/deconv layer + with an optional batch_norm layer + + :param bn: whether to use batch_norm_layer + :type bn: bool + :param trans: whether to use conv (False) or deconv (True) + :type trans: bool + """ + + # calculate the filter_size and padding size based on the given + # imgSize and ouput size tmp = imgSize - (output_x - 1) * stride if tmp <= 1 or tmp > 5: raise ValueError("conv input-output dimension does not fit") @@ -240,7 +259,6 @@ if is_generator_training or is_discriminator_training: classification_error_evaluator(input=prob, label=label, name=mode+'_error') outputs(cost) - if is_generator: noise = data_layer(name="noise", size=noise_dim) outputs(generator(noise)) diff --git a/demo/gan/gan_trainer.py b/demo/gan/gan_trainer.py index 6385bae011..063a98acdc 100644 --- a/demo/gan/gan_trainer.py +++ b/demo/gan/gan_trainer.py @@ -13,20 +13,22 @@ # limitations under the License. import argparse -import itertools import random import numpy +import cPickle +import sys,os +from PIL import Image from paddle.trainer.config_parser import parse_config from paddle.trainer.config_parser import logger import py_paddle.swig_paddle as api -from py_paddle import DataProviderConverter - import matplotlib.pyplot as plt - def plot2DScatter(data, outputfile): - # Generate some test data + ''' + Plot the data as a 2D scatter plot and save to outputfile + data needs to be two dimensinoal + ''' x = data[:, 0] y = data[:, 1] print "The mean vector is %s" % numpy.mean(data, 0) @@ -37,14 +39,19 @@ def plot2DScatter(data, outputfile): plt.clf() plt.scatter(x, y) - # plt.show() plt.savefig(outputfile, bbox_inches='tight') def CHECK_EQ(a, b): assert a == b, "a=%s, b=%s" % (a, b) - def copy_shared_parameters(src, dst): + ''' + copy the parameters from src to dst + :param src: the source of the parameters + :type src: GradientMachine + :param dst: the destination of the parameters + :type dst: GradientMachine + ''' src_params = [src.getParameter(i) for i in xrange(src.getParameterSize())] src_params = dict([(p.getName(), p) for p in src_params]) @@ -69,14 +76,77 @@ def print_parameters(src): for p in src_params: print "Name is %s" % p.getName() print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray() - -def get_real_samples(batch_size, sample_dim): - return numpy.random.rand(batch_size, sample_dim).astype('float32') - # return numpy.random.normal(loc=100.0, scale=100.0, size=(batch_size, sample_dim)).astype('float32') -def get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim): - gen_inputs = prepare_generator_data_batch(batch_size, noise_dim) - gen_inputs.resize(1) +def load_mnist_data(imageFile): + f = open(imageFile, "rb") + f.read(16) + + # Define number of samples for train/test + if "train" in imageFile: + n = 60000 + else: + n = 10000 + + data = numpy.zeros((n, 28*28), dtype = "float32") + + for i in range(n): + pixels = [] + for j in range(28 * 28): + pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0) + data[i, :] = pixels + + f.close() + return data + +def load_cifar_data(cifar_path): + batch_size = 10000 + data = numpy.zeros((5*batch_size, 32*32*3), dtype = "float32") + for i in range(1, 6): + file = cifar_path + "/data_batch_" + str(i) + fo = open(file, 'rb') + dict = cPickle.load(fo) + fo.close() + data[(i - 1)*batch_size:(i*batch_size), :] = dict["data"] + + data = data / 255.0 * 2.0 - 1.0 + return data + +# synthesize 2-D uniform data +def load_uniform_data(): + data = numpy.random.rand(1000000, 2).astype('float32') + return data + +def merge(images, size): + if images.shape[1] == 28*28: + h, w, c = 28, 28, 1 + else: + h, w, c = 32, 32, 3 + img = numpy.zeros((h * size[0], w * size[1], c)) + for idx in xrange(size[0] * size[1]): + i = idx % size[1] + j = idx // size[1] + img[j*h:j*h+h, i*w:i*w+w, :] = \ + ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0) + return img.astype('uint8') + +def saveImages(images, path): + merged_img = merge(images, [8, 8]) + if merged_img.shape[2] == 1: + im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB') + else: + im = Image.fromarray(merged_img, mode="RGB") + im.save(path) + +def get_real_samples(batch_size, data_np): + return data_np[numpy.random.choice(data_np.shape[0], batch_size, + replace=False),:] + +def get_noise(batch_size, noise_dim): + return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32') + +def get_fake_samples(generator_machine, batch_size, noise): + gen_inputs = api.Arguments.createArguments(1) + gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) gen_outputs = api.Arguments.createArguments(0) generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST) fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat() @@ -88,41 +158,27 @@ def get_training_loss(training_machine, inputs): loss = outputs.getSlotValue(0).copyToNumpyMat() return numpy.mean(loss) -def prepare_discriminator_data_batch( - generator_machine, batch_size, noise_dim, sample_dim): - fake_samples = get_fake_samples(generator_machine, batch_size / 2, noise_dim, sample_dim) - real_samples = get_real_samples(batch_size / 2, sample_dim) - all_samples = numpy.concatenate((fake_samples, real_samples), 0) - all_labels = numpy.concatenate( - (numpy.zeros(batch_size / 2, dtype='int32'), - numpy.ones(batch_size / 2, dtype='int32')), 0) - inputs = api.Arguments.createArguments(2) - inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(all_samples)) - inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumy(all_labels)) - return inputs - -def prepare_discriminator_data_batch_pos(batch_size, noise_dim, sample_dim): - real_samples = get_real_samples(batch_size, sample_dim) +def prepare_discriminator_data_batch_pos(batch_size, data_np): + real_samples = get_real_samples(batch_size, data_np) labels = numpy.ones(batch_size, dtype='int32') inputs = api.Arguments.createArguments(2) - inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(real_samples)) - inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels)) + inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_samples)) + inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) return inputs -def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise_dim, sample_dim): - fake_samples = get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim) +def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise): + fake_samples = get_fake_samples(generator_machine, batch_size, noise) labels = numpy.zeros(batch_size, dtype='int32') inputs = api.Arguments.createArguments(2) - inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(fake_samples)) - inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(labels)) + inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(fake_samples)) + inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) return inputs -def prepare_generator_data_batch(batch_size, dim): - noise = numpy.random.normal(size=(batch_size, dim)).astype('float32') +def prepare_generator_data_batch(batch_size, noise): label = numpy.ones(batch_size, dtype='int32') inputs = api.Arguments.createArguments(2) - inputs.setSlotValue(0, api.Matrix.createGpuDenseFromNumpy(noise)) - inputs.setSlotIds(1, api.IVector.createGpuVectorFromNumpy(label)) + inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) + inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label)) return inputs @@ -140,19 +196,48 @@ def get_layer_size(model_conf, layer_name): def main(): - api.initPaddle('--use_gpu=1', '--dot_period=10', '--log_period=100', - '--gpu_id=2') - gen_conf = parse_config("gan_conf.py", "mode=generator_training") - dis_conf = parse_config("gan_conf.py", "mode=discriminator_training") - generator_conf = parse_config("gan_conf.py", "mode=generator") + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--dataSource", help="mnist or cifar or uniform") + parser.add_argument("--useGpu", default="1", + help="1 means use gpu for training") + parser.add_argument("--gpuId", default="0", + help="the gpu_id parameter") + args = parser.parse_args() + dataSource = args.dataSource + useGpu = args.useGpu + assert dataSource in ["mnist", "cifar", "uniform"] + assert useGpu in ["0", "1"] + + api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100', + '--gpu_id=' + args.gpuId) + + if dataSource == "uniform": + conf = "gan_conf.py" + num_iter = 10000 + else: + conf = "gan_conf_image.py" + num_iter = 1000 + + gen_conf = parse_config(conf, "mode=generator_training,data=" + dataSource) + dis_conf = parse_config(conf, "mode=discriminator_training,data=" + dataSource) + generator_conf = parse_config(conf, "mode=generator,data=" + dataSource) batch_size = dis_conf.opt_config.batch_size noise_dim = get_layer_size(gen_conf.model_config, "noise") - sample_dim = get_layer_size(dis_conf.model_config, "sample") - + + if dataSource == "mnist": + data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte") + elif dataSource == "cifar": + data_np = load_cifar_data("./data/cifar-10-batches-py/") + else: + data_np = load_uniform_data() + + if not os.path.exists("./%s_samples/" % dataSource): + os.makedirs("./%s_samples/" % dataSource) + # this create a gradient machine for discriminator dis_training_machine = api.GradientMachine.createFromConfigProto( dis_conf.model_config) - + # this create a gradient machine for generator gen_training_machine = api.GradientMachine.createFromConfigProto( gen_conf.model_config) @@ -161,57 +246,64 @@ def main(): logger.info(str(generator_conf.model_config)) generator_machine = api.GradientMachine.createFromConfigProto( generator_conf.model_config) - + dis_trainer = api.Trainer.create( dis_conf, dis_training_machine) gen_trainer = api.Trainer.create( gen_conf, gen_training_machine) - + dis_trainer.startTrain() gen_trainer.startTrain() + + # Sync parameters between networks (GradientMachine) at the beginning copy_shared_parameters(gen_training_machine, dis_training_machine) copy_shared_parameters(gen_training_machine, generator_machine) + + # constrain that either discriminator or generator can not be trained + # consecutively more than MAX_strike times curr_train = "dis" curr_strike = 0 MAX_strike = 5 - + for train_pass in xrange(100): dis_trainer.startTrainPass() gen_trainer.startTrainPass() - for i in xrange(1000): -# data_batch_dis = prepare_discriminator_data_batch( -# generator_machine, batch_size, noise_dim, sample_dim) -# dis_loss = get_training_loss(dis_training_machine, data_batch_dis) + for i in xrange(num_iter): + # Do forward pass in discriminator to get the dis_loss + noise = get_noise(batch_size, noise_dim) data_batch_dis_pos = prepare_discriminator_data_batch_pos( - batch_size, noise_dim, sample_dim) + batch_size, data_np) dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos) data_batch_dis_neg = prepare_discriminator_data_batch_neg( - generator_machine, batch_size, noise_dim, sample_dim) + generator_machine, batch_size, noise) dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg) - + dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0 + # Do forward pass in generator to get the gen_loss data_batch_gen = prepare_generator_data_batch( - batch_size, noise_dim) + batch_size, noise) gen_loss = get_training_loss(gen_training_machine, data_batch_gen) - - if i % 1000 == 0: + + if i % 100 == 0: + print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg) print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss) - - if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss): + + # Decide which network to train based on the training history + # And the relative size of the loss + if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \ + ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss): if curr_train == "dis": curr_strike += 1 else: curr_train = "dis" curr_strike = 1 dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg) - dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) -# dis_loss = numpy.mean(dis_trainer.getForwardOutput()[0]["value"]) -# print "getForwardOutput loss is %s" % dis_loss + dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) copy_shared_parameters(dis_training_machine, gen_training_machine) - + else: if curr_train == "gen": curr_strike += 1 @@ -221,12 +313,15 @@ def main(): gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) copy_shared_parameters(gen_training_machine, dis_training_machine) copy_shared_parameters(gen_training_machine, generator_machine) - + dis_trainer.finishTrainPass() gen_trainer.finishTrainPass() - - fake_samples = get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim) - plot2DScatter(fake_samples, "./train_pass%s.png" % train_pass) + # At the end of each pass, save the generated samples/images + fake_samples = get_fake_samples(generator_machine, batch_size, noise) + if dataSource == "uniform": + plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) + else: + saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) dis_trainer.finishTrain() gen_trainer.finishTrain() diff --git a/demo/gan/gan_trainer_image.py b/demo/gan/gan_trainer_image.py deleted file mode 100644 index b4062b213e..0000000000 --- a/demo/gan/gan_trainer_image.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) 2016 Baidu, Inc. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import random -import numpy -import cPickle -import sys,os -from PIL import Image - -from paddle.trainer.config_parser import parse_config -from paddle.trainer.config_parser import logger -import py_paddle.swig_paddle as api -import matplotlib.pyplot as plt - -def plot2DScatter(data, outputfile): - x = data[:, 0] - y = data[:, 1] - print "The mean vector is %s" % numpy.mean(data, 0) - print "The std vector is %s" % numpy.std(data, 0) - - heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50) - extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] - - plt.clf() - plt.scatter(x, y) - plt.savefig(outputfile, bbox_inches='tight') - -def CHECK_EQ(a, b): - assert a == b, "a=%s, b=%s" % (a, b) - - -def copy_shared_parameters(src, dst): - src_params = [src.getParameter(i) - for i in xrange(src.getParameterSize())] - src_params = dict([(p.getName(), p) for p in src_params]) - - - for i in xrange(dst.getParameterSize()): - dst_param = dst.getParameter(i) - src_param = src_params.get(dst_param.getName(), None) - if src_param is None: - continue - src_value = src_param.getBuf(api.PARAMETER_VALUE) - dst_value = dst_param.getBuf(api.PARAMETER_VALUE) - CHECK_EQ(len(src_value), len(dst_value)) - dst_value.copyFrom(src_value) - dst_param.setValueUpdated() - -def print_parameters(src): - src_params = [src.getParameter(i) - for i in xrange(src.getParameterSize())] - - print "***************" - for p in src_params: - print "Name is %s" % p.getName() - print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray() - -def load_mnist_data(imageFile): - f = open(imageFile, "rb") - f.read(16) - - # Define number of samples for train/test - if "train" in imageFile: - n = 60000 - else: - n = 10000 - - data = numpy.zeros((n, 28*28), dtype = "float32") - - for i in range(n): - pixels = [] - for j in range(28 * 28): - pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0) - data[i, :] = pixels - - f.close() - return data - -def load_cifar_data(cifar_path): - batch_size = 10000 - data = numpy.zeros((5*batch_size, 32*32*3), dtype = "float32") - for i in range(1, 6): - file = cifar_path + "/data_batch_" + str(i) - fo = open(file, 'rb') - dict = cPickle.load(fo) - fo.close() - data[(i - 1)*batch_size:(i*batch_size), :] = dict["data"] - - data = data / 255.0 * 2.0 - 1.0 - return data - -# synthesize 2-D uniform data -def load_uniform_data(): - data = numpy.random.rand(1000000, 2).astype('float32') - return data - -def merge(images, size): - if images.shape[1] == 28*28: - h, w, c = 28, 28, 1 - else: - h, w, c = 32, 32, 3 - img = numpy.zeros((h * size[0], w * size[1], c)) - for idx in xrange(size[0] * size[1]): - i = idx % size[1] - j = idx // size[1] - img[j*h:j*h+h, i*w:i*w+w, :] = \ - ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0) - return img.astype('uint8') - -def saveImages(images, path): - merged_img = merge(images, [8, 8]) - if merged_img.shape[2] == 1: - im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB') - else: - im = Image.fromarray(merged_img, mode="RGB") - im.save(path) - -def get_real_samples(batch_size, data_np): - return data_np[numpy.random.choice(data_np.shape[0], batch_size, - replace=False),:] - -def get_noise(batch_size, noise_dim): - return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32') - -def get_fake_samples(generator_machine, batch_size, noise): - gen_inputs = api.Arguments.createArguments(1) - gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) - gen_outputs = api.Arguments.createArguments(0) - generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST) - fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat() - return fake_samples - -def get_training_loss(training_machine, inputs): - outputs = api.Arguments.createArguments(0) - training_machine.forward(inputs, outputs, api.PASS_TEST) - loss = outputs.getSlotValue(0).copyToNumpyMat() - return numpy.mean(loss) - -def prepare_discriminator_data_batch_pos(batch_size, data_np): - real_samples = get_real_samples(batch_size, data_np) - labels = numpy.ones(batch_size, dtype='int32') - inputs = api.Arguments.createArguments(2) - inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_samples)) - inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) - return inputs - -def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise): - fake_samples = get_fake_samples(generator_machine, batch_size, noise) - labels = numpy.zeros(batch_size, dtype='int32') - inputs = api.Arguments.createArguments(2) - inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(fake_samples)) - inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) - return inputs - -def prepare_generator_data_batch(batch_size, noise): - label = numpy.ones(batch_size, dtype='int32') - inputs = api.Arguments.createArguments(2) - inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) - inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label)) - return inputs - - -def find(iterable, cond): - for item in iterable: - if cond(item): - return item - return None - - -def get_layer_size(model_conf, layer_name): - layer_conf = find(model_conf.layers, lambda x: x.name == layer_name) - assert layer_conf is not None, "Cannot find '%s' layer" % layer_name - return layer_conf.size - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--dataSource", help="mnist or cifar or uniform") - parser.add_argument("--useGpu", default="1", - help="1 means use gpu for training") - parser.add_argument("--gpuId", default="0", - help="the gpu_id parameter") - args = parser.parse_args() - dataSource = args.dataSource - useGpu = args.useGpu - assert dataSource in ["mnist", "cifar", "uniform"] - assert useGpu in ["0", "1"] - - api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100', - '--gpu_id=' + args.gpuId) - - if dataSource == "uniform": - conf = "gan_conf.py" - num_iter = 10000 - else: - conf = "gan_conf_image.py" - num_iter = 1000 - - gen_conf = parse_config(conf, "mode=generator_training,data=" + dataSource) - dis_conf = parse_config(conf, "mode=discriminator_training,data=" + dataSource) - generator_conf = parse_config(conf, "mode=generator,data=" + dataSource) - batch_size = dis_conf.opt_config.batch_size - noise_dim = get_layer_size(gen_conf.model_config, "noise") - - if dataSource == "mnist": - data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte") - elif dataSource == "cifar": - data_np = load_cifar_data("./data/cifar-10-batches-py/") - else: - data_np = load_uniform_data() - - if not os.path.exists("./%s_samples/" % dataSource): - os.makedirs("./%s_samples/" % dataSource) - - # this create a gradient machine for discriminator - dis_training_machine = api.GradientMachine.createFromConfigProto( - dis_conf.model_config) - - gen_training_machine = api.GradientMachine.createFromConfigProto( - gen_conf.model_config) - - # generator_machine is used to generate data only, which is used for - # training discrinator - logger.info(str(generator_conf.model_config)) - generator_machine = api.GradientMachine.createFromConfigProto( - generator_conf.model_config) - - dis_trainer = api.Trainer.create( - dis_conf, dis_training_machine) - - gen_trainer = api.Trainer.create( - gen_conf, gen_training_machine) - - dis_trainer.startTrain() - gen_trainer.startTrain() - - copy_shared_parameters(gen_training_machine, dis_training_machine) - copy_shared_parameters(gen_training_machine, generator_machine) - - # constrain that either discriminator or generator can not be trained - # consecutively more than MAX_strike times - curr_train = "dis" - curr_strike = 0 - MAX_strike = 5 - - for train_pass in xrange(100): - dis_trainer.startTrainPass() - gen_trainer.startTrainPass() - for i in xrange(num_iter): - noise = get_noise(batch_size, noise_dim) - data_batch_dis_pos = prepare_discriminator_data_batch_pos( - batch_size, data_np) - dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos) - - data_batch_dis_neg = prepare_discriminator_data_batch_neg( - generator_machine, batch_size, noise) - dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg) - - dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0 - - data_batch_gen = prepare_generator_data_batch( - batch_size, noise) - gen_loss = get_training_loss(gen_training_machine, data_batch_gen) - - if i % 100 == 0: - print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg) - print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss) - - if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \ - ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss): - if curr_train == "dis": - curr_strike += 1 - else: - curr_train = "dis" - curr_strike = 1 - dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg) - dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) - copy_shared_parameters(dis_training_machine, gen_training_machine) - - else: - if curr_train == "gen": - curr_strike += 1 - else: - curr_train = "gen" - curr_strike = 1 - gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) - copy_shared_parameters(gen_training_machine, dis_training_machine) - copy_shared_parameters(gen_training_machine, generator_machine) - - dis_trainer.finishTrainPass() - gen_trainer.finishTrainPass() - fake_samples = get_fake_samples(generator_machine, batch_size, noise) - if dataSource == "uniform": - plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) - else: - saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass)) - dis_trainer.finishTrain() - gen_trainer.finishTrain() - -if __name__ == '__main__': - main() -- GitLab