From c159e4dd78fa58fb4d8e23b4ab2210685530f36d Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Thu, 3 Nov 2016 17:14:22 -0700 Subject: [PATCH] added get_mnist_data and demo/gan and updated the gan_conf and gan_trainer python files --- .gitignore | 1 + demo/gan/.gitignore | 3 +- demo/gan/data/get_mnist_data.sh | 19 +++++ demo/gan/gan_conf.py | 58 +++++++++++++-- demo/gan/gan_conf_image.py | 6 +- demo/gan/gan_trainer.py | 120 ++++++++++++++++++++++++++++---- 6 files changed, 182 insertions(+), 25 deletions(-) create mode 100644 demo/gan/data/get_mnist_data.sh diff --git a/.gitignore b/.gitignore index ee8489c1d7..35bed0accd 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build/ .cproject .pydevproject Makefile +.test_env/ diff --git a/demo/gan/.gitignore b/demo/gan/.gitignore index 828646b136..91ac27fe63 100644 --- a/demo/gan/.gitignore +++ b/demo/gan/.gitignore @@ -2,6 +2,5 @@ output/ *.png .pydevproject .project -data/ trainLog.txt - +data/raw_data/ diff --git a/demo/gan/data/get_mnist_data.sh b/demo/gan/data/get_mnist_data.sh new file mode 100644 index 0000000000..3a6aa51322 --- /dev/null +++ b/demo/gan/data/get_mnist_data.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env sh +# This scripts downloads the mnist data and unzips it. +set -e +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +rm -rf "$DIR/raw_data" +mkdir "$DIR/raw_data" +cd "$DIR/raw_data" + +echo "Downloading..." + +for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte +do + if [ ! -e $fname ]; then + wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz + gunzip ${fname}.gz + fi +done + + diff --git a/demo/gan/gan_conf.py b/demo/gan/gan_conf.py index df85cd70c7..5b4a2bbf8b 100644 --- a/demo/gan/gan_conf.py +++ b/demo/gan/gan_conf.py @@ -26,11 +26,12 @@ is_discriminator = mode == "discriminator" print('mode=%s' % mode) noise_dim = 10 +hidden_dim = 15 sample_dim = 2 settings( - batch_size=100, - learning_rate=1e-2, + batch_size=128, + learning_rate=1e-4, learning_method=AdamOptimizer() ) @@ -44,9 +45,30 @@ def discriminator(sample): """ param_attr = ParamAttr(is_static=is_generator_training) bias_attr = ParamAttr(is_static=is_generator_training, - initial_mean=0, + initial_mean=1.0, initial_std=0) - return fc_layer(input=sample, name="dis_prob", size=2, + hidden = fc_layer(input=sample, name="dis_hidden", size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + act=ReluActivation()) + #act=LinearActivation()) + + hidden2 = fc_layer(input=hidden, name="dis_hidden2", size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + #act=ReluActivation()) + act=LinearActivation()) + + hidden_bn = batch_norm_layer(hidden2, + act=ReluActivation(), + name="dis_hidden_bn", + bias_attr=bias_attr, + param_attr=ParamAttr(is_static=is_generator_training, + initial_mean=1.0, + initial_std=0.02), + use_global_stats=False) + + return fc_layer(input=hidden_bn, name="dis_prob", size=2, bias_attr=bias_attr, param_attr=param_attr, act=SoftmaxActivation()) @@ -57,9 +79,33 @@ def generator(noise): """ param_attr = ParamAttr(is_static=is_discriminator_training) bias_attr = ParamAttr(is_static=is_discriminator_training, - initial_mean=0, + initial_mean=1.0, initial_std=0) - return fc_layer(input=noise, + + hidden = fc_layer(input=noise, + name="gen_layer_hidden", + size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + act=ReluActivation()) + #act=LinearActivation()) + + hidden2 = fc_layer(input=hidden, name="gen_hidden2", size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + #act=ReluActivation()) + act=LinearActivation()) + + hidden_bn = batch_norm_layer(hidden2, + act=ReluActivation(), + name="gen_layer_hidden_bn", + bias_attr=bias_attr, + param_attr=ParamAttr(is_static=is_discriminator_training, + initial_mean=1.0, + initial_std=0.02), + use_global_stats=False) + + return fc_layer(input=hidden_bn, name="gen_layer1", size=sample_dim, bias_attr=bias_attr, diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py index 58bed2b189..9a4f2a4ea4 100644 --- a/demo/gan/gan_conf_image.py +++ b/demo/gan/gan_conf_image.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from paddle.trainer_config_helpers import * -from paddle.trainer_config_helpers.layers import img_convTrans_layer from paddle.trainer_config_helpers.activations import LinearActivation from numpy.distutils.system_info import tmp @@ -55,13 +54,14 @@ def convTrans_bn(input, channels, output_x, num_filters, imgSize, stride, name, padding = 0 - convTrans = img_convTrans_layer(input, filter_size=filter_size, + convTrans = img_conv_layer(input, filter_size=filter_size, num_filters=num_filters, name=name + "_convt", num_channels=channels, act=LinearActivation(), groups=1, stride=stride, padding=padding, bias_attr=bias_attr, param_attr=param_attr, shared_biases=True, layer_attr=None, - filter_size_y=None, stride_y=None, padding_y=None) + filter_size_y=None, stride_y=None, padding_y=None, + trans=True) convTrans_bn = batch_norm_layer(convTrans, act=ReluActivation(), diff --git a/demo/gan/gan_trainer.py b/demo/gan/gan_trainer.py index ae13b3b27d..e64f0ffa0d 100644 --- a/demo/gan/gan_trainer.py +++ b/demo/gan/gan_trainer.py @@ -22,6 +22,23 @@ from paddle.trainer.config_parser import logger import py_paddle.swig_paddle as api from py_paddle import DataProviderConverter +import matplotlib.pyplot as plt + + +def plot2DScatter(data, outputfile): + # Generate some test data + x = data[:, 0] + y = data[:, 1] + print "The mean vector is %s" % numpy.mean(data, 0) + print "The std vector is %s" % numpy.std(data, 0) + + heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50) + extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] + + plt.clf() + plt.scatter(x, y) + # plt.show() + plt.savefig(outputfile, bbox_inches='tight') def CHECK_EQ(a, b): assert a == b, "a=%s, b=%s" % (a, b) @@ -32,6 +49,7 @@ def copy_shared_parameters(src, dst): for i in xrange(src.getParameterSize())] src_params = dict([(p.getName(), p) for p in src_params]) + for i in xrange(dst.getParameterSize()): dst_param = dst.getParameter(i) src_param = src_params.get(dst_param.getName(), None) @@ -42,19 +60,37 @@ def copy_shared_parameters(src, dst): CHECK_EQ(len(src_value), len(dst_value)) dst_value.copyFrom(src_value) dst_param.setValueUpdated() + +def print_parameters(src): + src_params = [src.getParameter(i) + for i in xrange(src.getParameterSize())] - + print "***************" + for p in src_params: + print "Name is %s" % p.getName() + print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray() + def get_real_samples(batch_size, sample_dim): - return numpy.random.rand(batch_size, sample_dim).astype('float32') - + return numpy.random.rand(batch_size, sample_dim).astype('float32') * 10.0 - 10.0 + # return numpy.random.normal(loc=100.0, scale=100.0, size=(batch_size, sample_dim)).astype('float32') -def prepare_discriminator_data_batch( - generator_machine, batch_size, noise_dim, sample_dim): - gen_inputs = prepare_generator_data_batch(batch_size / 2, noise_dim) +def get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim): + gen_inputs = prepare_generator_data_batch(batch_size, noise_dim) gen_inputs.resize(1) gen_outputs = api.Arguments.createArguments(0) generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST) fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat() + return fake_samples + +def get_training_loss(training_machine, inputs): + outputs = api.Arguments.createArguments(0) + training_machine.forward(inputs, outputs, api.PASS_TEST) + loss = outputs.getSlotValue(0).copyToNumpyMat() + return numpy.mean(loss) + +def prepare_discriminator_data_batch( + generator_machine, batch_size, noise_dim, sample_dim): + fake_samples = get_fake_samples(generator_machine, batch_size / 2, noise_dim, sample_dim) real_samples = get_real_samples(batch_size / 2, sample_dim) all_samples = numpy.concatenate((fake_samples, real_samples), 0) all_labels = numpy.concatenate( @@ -65,6 +101,21 @@ def prepare_discriminator_data_batch( inputs.setSlotIds(1, api.IVector.createCpuVectorFromNumpy(all_labels)) return inputs +def prepare_discriminator_data_batch_pos(batch_size, noise_dim, sample_dim): + real_samples = get_real_samples(batch_size, sample_dim) + labels = numpy.ones(batch_size, dtype='int32') + inputs = api.Arguments.createArguments(2) + inputs.setSlotValue(0, api.Matrix.createCpuDenseFromNumpy(real_samples)) + inputs.setSlotIds(1, api.IVector.createCpuVectorFromNumpy(labels)) + return inputs + +def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise_dim, sample_dim): + fake_samples = get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim) + labels = numpy.zeros(batch_size, dtype='int32') + inputs = api.Arguments.createArguments(2) + inputs.setSlotValue(0, api.Matrix.createCpuDenseFromNumpy(fake_samples)) + inputs.setSlotIds(1, api.IVector.createCpuVectorFromNumpy(labels)) + return inputs def prepare_generator_data_batch(batch_size, dim): noise = numpy.random.normal(size=(batch_size, dim)).astype('float32') @@ -118,22 +169,63 @@ def main(): dis_trainer.startTrain() gen_trainer.startTrain() + copy_shared_parameters(gen_training_machine, dis_training_machine) + copy_shared_parameters(gen_training_machine, generator_machine) + curr_train = "dis" + curr_strike = 0 + MAX_strike = 5 + for train_pass in xrange(10): dis_trainer.startTrainPass() gen_trainer.startTrainPass() for i in xrange(100000): - copy_shared_parameters(gen_training_machine, generator_machine) - copy_shared_parameters(gen_training_machine, dis_training_machine) - data_batch = prepare_discriminator_data_batch( +# data_batch_dis = prepare_discriminator_data_batch( +# generator_machine, batch_size, noise_dim, sample_dim) +# dis_loss = get_training_loss(dis_training_machine, data_batch_dis) + data_batch_dis_pos = prepare_discriminator_data_batch_pos( + batch_size, noise_dim, sample_dim) + dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos) + + data_batch_dis_neg = prepare_discriminator_data_batch_neg( generator_machine, batch_size, noise_dim, sample_dim) - dis_trainer.trainOneDataBatch(batch_size, data_batch) + dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg) + + dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0 + + data_batch_gen = prepare_generator_data_batch( + batch_size, noise_dim) + gen_loss = get_training_loss(gen_training_machine, data_batch_gen) + + if i % 1000 == 0: + print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss) + + if (not (curr_train == "dis" and curr_strike == MAX_strike)) and ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > 0.690 or dis_loss > gen_loss): + if curr_train == "dis": + curr_strike += 1 + else: + curr_train = "dis" + curr_strike = 1 + dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg) + dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) +# dis_loss = numpy.mean(dis_trainer.getForwardOutput()[0]["value"]) +# print "getForwardOutput loss is %s" % dis_loss + copy_shared_parameters(dis_training_machine, gen_training_machine) + + else: + if curr_train == "gen": + curr_strike += 1 + else: + curr_train = "gen" + curr_strike = 1 + gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) + copy_shared_parameters(gen_training_machine, dis_training_machine) + copy_shared_parameters(gen_training_machine, generator_machine) - copy_shared_parameters(dis_training_machine, gen_training_machine) - data_batch = prepare_generator_data_batch( - batch_size, noise_dim) - gen_trainer.trainOneDataBatch(batch_size, data_batch) dis_trainer.finishTrainPass() gen_trainer.finishTrainPass() + + fake_samples = get_fake_samples(generator_machine, batch_size, noise_dim, sample_dim) + plot2DScatter(fake_samples, "./train_pass%s.png" % train_pass) dis_trainer.finishTrain() gen_trainer.finishTrain() -- GitLab