# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import random import numpy import cPickle import sys, os from PIL import Image from paddle.trainer.config_parser import parse_config from paddle.trainer.config_parser import logger import py_paddle.swig_paddle as api import matplotlib.pyplot as plt def CHECK_EQ(a, b): assert a == b, "a=%s, b=%s" % (a, b) def copy_shared_parameters(src, dst): ''' copy the parameters from src to dst :param src: the source of the parameters :type src: GradientMachine :param dst: the destination of the parameters :type dst: GradientMachine ''' src_params = [src.getParameter(i) for i in xrange(src.getParameterSize())] src_params = dict([(p.getName(), p) for p in src_params]) for i in xrange(dst.getParameterSize()): dst_param = dst.getParameter(i) src_param = src_params.get(dst_param.getName(), None) if src_param is None: continue src_value = src_param.getBuf(api.PARAMETER_VALUE) dst_value = dst_param.getBuf(api.PARAMETER_VALUE) CHECK_EQ(len(src_value), len(dst_value)) dst_value.copyFrom(src_value) dst_param.setValueUpdated() def print_parameters(src): src_params = [src.getParameter(i) for i in xrange(src.getParameterSize())] print "***************" for p in src_params: print "Name is %s" % p.getName() print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray( ) # synthesize 2-D uniform data def load_uniform_data(): data = numpy.random.rand(1000000, 2).astype('float32') return data def load_mnist_data(imageFile): f = open(imageFile, "rb") f.read(16) # Define number of samples for train/test if "train" in imageFile: n = 60000 else: n = 10000 data = numpy.fromfile(f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)) data = data / 255.0 * 2.0 - 1.0 f.close() return data.astype('float32') def load_cifar_data(cifar_path): batch_size = 10000 data = numpy.zeros((5 * batch_size, 32 * 32 * 3), dtype="float32") for i in range(1, 6): file = cifar_path + "/data_batch_" + str(i) fo = open(file, 'rb') dict = cPickle.load(fo) fo.close() data[(i - 1) * batch_size:(i * batch_size), :] = dict["data"] data = data / 255.0 * 2.0 - 1.0 return data def plot2DScatter(data, outputfile): ''' Plot the data as a 2D scatter plot and save to outputfile data needs to be two dimensinoal ''' x = data[:, 0] y = data[:, 1] logger.info("The mean vector is %s" % numpy.mean(data, 0)) logger.info("The std vector is %s" % numpy.std(data, 0)) heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50) extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] plt.clf() plt.scatter(x, y) plt.savefig(outputfile, bbox_inches='tight') def merge(images, size): if images.shape[1] == 28 * 28: h, w, c = 28, 28, 1 else: h, w, c = 32, 32, 3 img = numpy.zeros((h * size[0], w * size[1], c)) for idx in xrange(size[0] * size[1]): i = idx % size[1] j = idx // size[1] img[j*h:j*h+h, i*w:i*w+w, :] = \ ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0) return img.astype('uint8') def save_images(images, path): merged_img = merge(images, [8, 8]) if merged_img.shape[2] == 1: im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB') else: im = Image.fromarray(merged_img, mode="RGB") im.save(path) def save_results(samples, path, data_source): if data_source == "uniform": plot2DScatter(samples, path) else: save_images(samples, path) def get_real_samples(batch_size, data_np): return data_np[numpy.random.choice( data_np.shape[0], batch_size, replace=False), :] def get_noise(batch_size, noise_dim): return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32') def get_fake_samples(generator_machine, batch_size, noise): gen_inputs = api.Arguments.createArguments(1) gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) gen_outputs = api.Arguments.createArguments(0) generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST) fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat() return fake_samples def get_training_loss(training_machine, inputs): outputs = api.Arguments.createArguments(0) training_machine.forward(inputs, outputs, api.PASS_TEST) loss = outputs.getSlotValue(0).copyToNumpyMat() return numpy.mean(loss) def prepare_discriminator_data_batch_pos(batch_size, data_np): real_samples = get_real_samples(batch_size, data_np) labels = numpy.ones(batch_size, dtype='int32') inputs = api.Arguments.createArguments(2) inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_samples)) inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) return inputs def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise): fake_samples = get_fake_samples(generator_machine, batch_size, noise) labels = numpy.zeros(batch_size, dtype='int32') inputs = api.Arguments.createArguments(2) inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(fake_samples)) inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) return inputs def prepare_generator_data_batch(batch_size, noise): label = numpy.ones(batch_size, dtype='int32') inputs = api.Arguments.createArguments(2) inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label)) return inputs def find(iterable, cond): for item in iterable: if cond(item): return item return None def get_layer_size(model_conf, layer_name): layer_conf = find(model_conf.layers, lambda x: x.name == layer_name) assert layer_conf is not None, "Cannot find '%s' layer" % layer_name return layer_conf.size def main(): parser = argparse.ArgumentParser() parser.add_argument("-d", "--data_source", help="mnist or cifar or uniform") parser.add_argument( "--use_gpu", default="1", help="1 means use gpu for training") parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter") parser.add_argument( "--model_dir", default="", help="model path for generating samples, empty means training mode") args = parser.parse_args() data_source = args.data_source use_gpu = args.use_gpu model_dir = args.model_dir assert data_source in ["mnist", "cifar", "uniform"] assert use_gpu in ["0", "1"] if not os.path.exists("./%s_samples/" % data_source): os.makedirs("./%s_samples/" % data_source) if not os.path.exists("./%s_params/" % data_source): os.makedirs("./%s_params/" % data_source) api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10', '--log_period=100', '--gpu_id=' + args.gpu_id, '--save_dir=' + "./%s_params/" % data_source) if data_source == "uniform": conf = "gan_conf.py" num_iter = 10000 else: conf = "gan_conf_image.py" num_iter = 1000 gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source) dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source) generator_conf = parse_config(conf, "mode=generator,data=" + data_source) logger.info(str(generator_conf.model_config)) batch_size = dis_conf.opt_config.batch_size noise_dim = get_layer_size(gen_conf.model_config, "noise") if data_source == "mnist": data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte") elif data_source == "cifar": data_np = load_cifar_data("./data/cifar-10-batches-py/") else: data_np = load_uniform_data() # this creates a gradient machine for discriminator dis_training_machine = api.GradientMachine.createFromConfigProto( dis_conf.model_config) # this create a gradient machine for generator gen_training_machine = api.GradientMachine.createFromConfigProto( gen_conf.model_config) # generator_machine is used to generate data only, which is used for # training discriminator generator_machine = api.GradientMachine.createFromConfigProto( generator_conf.model_config) # In the generating settings, use previously trained model to generate # fake samples if model_dir != "": generator_machine.loadParameters(model_dir) noise = get_noise(batch_size, noise_dim) fake_samples = get_fake_samples(generator_machine, batch_size, noise) save_results(fake_samples, "./generated_samples.png", data_source) return dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) dis_trainer.startTrain() gen_trainer.startTrain() # Sync parameters between networks (GradientMachine) at the beginning copy_shared_parameters(gen_training_machine, dis_training_machine) copy_shared_parameters(gen_training_machine, generator_machine) # constrain that either discriminator or generator can not be trained # consecutively more than MAX_strike times curr_train = "dis" curr_strike = 0 MAX_strike = 5 for train_pass in xrange(100): dis_trainer.startTrainPass() gen_trainer.startTrainPass() for i in xrange(num_iter): # Do forward pass in discriminator to get the dis_loss noise = get_noise(batch_size, noise_dim) data_batch_dis_pos = prepare_discriminator_data_batch_pos( batch_size, data_np) dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos) data_batch_dis_neg = prepare_discriminator_data_batch_neg( generator_machine, batch_size, noise) dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg) dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0 # Do forward pass in generator to get the gen_loss data_batch_gen = prepare_generator_data_batch(batch_size, noise) gen_loss = get_training_loss(gen_training_machine, data_batch_gen) if i % 100 == 0: print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg) print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss) # Decide which network to train based on the training history # And the relative size of the loss if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \ ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss): if curr_train == "dis": curr_strike += 1 else: curr_train = "dis" curr_strike = 1 dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg) dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) copy_shared_parameters(dis_training_machine, gen_training_machine) else: if curr_train == "gen": curr_strike += 1 else: curr_train = "gen" curr_strike = 1 gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) copy_shared_parameters(gen_training_machine, dis_training_machine) copy_shared_parameters(gen_training_machine, generator_machine) dis_trainer.finishTrainPass() gen_trainer.finishTrainPass() # At the end of each pass, save the generated samples/images fake_samples = get_fake_samples(generator_machine, batch_size, noise) save_results(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass), data_source) dis_trainer.finishTrain() gen_trainer.finishTrain() if __name__ == '__main__': main()