diff --git a/.gitignore b/.gitignore index ee8489c1d71bd050b9a1d9358a664d2294165292..35bed0accdaa274f5966ca5b4b7180106325449b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build/ .cproject .pydevproject Makefile +.test_env/ diff --git a/demo/gan/.gitignore b/demo/gan/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..93a6f5080a16a601cffb0bff51af9aef3ba3bae7 --- /dev/null +++ b/demo/gan/.gitignore @@ -0,0 +1,11 @@ +output/ +uniform_params/ +cifar_params/ +mnist_params/ +*.png +.pydevproject +.project +*.log +*.pyc +data/mnist_data/ +data/cifar-10-batches-py/ diff --git a/demo/gan/README.md b/demo/gan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fdc970a07b488c3a4146c9baa76a133a456fc9ab --- /dev/null +++ b/demo/gan/README.md @@ -0,0 +1,13 @@ +# Generative Adversarial Networks (GAN) + +This demo implements GAN training described in the original GAN paper (https://arxiv.org/abs/1406.2661) and DCGAN (https://arxiv.org/abs/1511.06434). + +The general training procedures are implemented in gan_trainer.py. The neural network configurations are specified in gan_conf.py (for synthetic data) and gan_conf_image.py (for image data). + +In order to run the model, first download the corresponding data by running the shell script in ./data. +Then you can run the command below. The flag -d specifies the training data (cifar, mnist or uniform) and flag --useGpu specifies whether to use gpu for training (0 is cpu, 1 is gpu). + +$python gan_trainer.py -d cifar --use_gpu 1 + +The generated images will be stored in ./cifar_samples/ +The corresponding models will be stored in ./cifar_params/ \ No newline at end of file diff --git a/demo/gan/data/download_cifar.sh b/demo/gan/data/download_cifar.sh new file mode 100755 index 0000000000000000000000000000000000000000..ea3be594cd08f829e94f2c692a44947baa62b759 --- /dev/null +++ b/demo/gan/data/download_cifar.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz +tar zxf cifar-10-python.tar.gz +rm cifar-10-python.tar.gz + diff --git a/demo/gan/data/get_mnist_data.sh b/demo/gan/data/get_mnist_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..d21bf7067135f1f8be486ef0f13fc3ec94ffc4ed --- /dev/null +++ b/demo/gan/data/get_mnist_data.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env sh +# This script downloads the mnist data and unzips it. +set -e +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +rm -rf "$DIR/mnist_data" +mkdir "$DIR/mnist_data" +cd "$DIR/mnist_data" + +echo "Downloading..." + +for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte +do + if [ ! -e $fname ]; then + wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz + gunzip ${fname}.gz + fi +done + + diff --git a/demo/gan/gan_conf.py b/demo/gan/gan_conf.py new file mode 100644 index 0000000000000000000000000000000000000000..05eee3a9b9ce455eb3a5d47d3165ee7f42f1002e --- /dev/null +++ b/demo/gan/gan_conf.py @@ -0,0 +1,134 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddle.trainer_config_helpers import * + +mode = get_config_arg("mode", str, "generator") +assert mode in set(["generator", + "discriminator", + "generator_training", + "discriminator_training"]) + +is_generator_training = mode == "generator_training" +is_discriminator_training = mode == "discriminator_training" +is_generator = mode == "generator" +is_discriminator = mode == "discriminator" + +# The network structure below follows the ref https://arxiv.org/abs/1406.2661 +# Here we used two hidden layers and batch_norm + +print('mode=%s' % mode) +# the dim of the noise (z) as the input of the generator network +noise_dim = 10 +# the dim of the hidden layer +hidden_dim = 10 +# the dim of the generated sample +sample_dim = 2 + +settings( + batch_size=128, + learning_rate=1e-4, + learning_method=AdamOptimizer(beta1=0.5) +) + +def discriminator(sample): + """ + discriminator ouputs the probablity of a sample is from generator + or real data. + The output has two dimenstional: dimension 0 is the probablity + of the sample is from generator and dimension 1 is the probabblity + of the sample is from real data. + """ + param_attr = ParamAttr(is_static=is_generator_training) + bias_attr = ParamAttr(is_static=is_generator_training, + initial_mean=1.0, + initial_std=0) + + hidden = fc_layer(input=sample, name="dis_hidden", size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + act=ReluActivation()) + + hidden2 = fc_layer(input=hidden, name="dis_hidden2", size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + act=LinearActivation()) + + hidden_bn = batch_norm_layer(hidden2, + act=ReluActivation(), + name="dis_hidden_bn", + bias_attr=bias_attr, + param_attr=ParamAttr(is_static=is_generator_training, + initial_mean=1.0, + initial_std=0.02), + use_global_stats=False) + + return fc_layer(input=hidden_bn, name="dis_prob", size=2, + bias_attr=bias_attr, + param_attr=param_attr, + act=SoftmaxActivation()) + +def generator(noise): + """ + generator generates a sample given noise + """ + param_attr = ParamAttr(is_static=is_discriminator_training) + bias_attr = ParamAttr(is_static=is_discriminator_training, + initial_mean=1.0, + initial_std=0) + + hidden = fc_layer(input=noise, + name="gen_layer_hidden", + size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + act=ReluActivation()) + + hidden2 = fc_layer(input=hidden, name="gen_hidden2", size=hidden_dim, + bias_attr=bias_attr, + param_attr=param_attr, + act=LinearActivation()) + + hidden_bn = batch_norm_layer(hidden2, + act=ReluActivation(), + name="gen_layer_hidden_bn", + bias_attr=bias_attr, + param_attr=ParamAttr(is_static=is_discriminator_training, + initial_mean=1.0, + initial_std=0.02), + use_global_stats=False) + + return fc_layer(input=hidden_bn, + name="gen_layer1", + size=sample_dim, + bias_attr=bias_attr, + param_attr=param_attr, + act=LinearActivation()) + +if is_generator_training: + noise = data_layer(name="noise", size=noise_dim) + sample = generator(noise) + +if is_discriminator_training: + sample = data_layer(name="sample", size=sample_dim) + +if is_generator_training or is_discriminator_training: + label = data_layer(name="label", size=1) + prob = discriminator(sample) + cost = cross_entropy(input=prob, label=label) + classification_error_evaluator(input=prob, label=label, name=mode+'_error') + outputs(cost) + +if is_generator: + noise = data_layer(name="noise", size=noise_dim) + outputs(generator(noise)) diff --git a/demo/gan/gan_conf_image.py b/demo/gan/gan_conf_image.py new file mode 100644 index 0000000000000000000000000000000000000000..dc5910e9f02d7aac59207fdaa0222d01ac3bf609 --- /dev/null +++ b/demo/gan/gan_conf_image.py @@ -0,0 +1,264 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddle.trainer_config_helpers import * + +mode = get_config_arg("mode", str, "generator") +dataSource = get_config_arg("data", str, "mnist") +assert mode in set(["generator", + "discriminator", + "generator_training", + "discriminator_training"]) + +is_generator_training = mode == "generator_training" +is_discriminator_training = mode == "discriminator_training" +is_generator = mode == "generator" +is_discriminator = mode == "discriminator" + +# The network structure below follows the dcgan paper +# (https://arxiv.org/abs/1511.06434) + +print('mode=%s' % mode) +# the dim of the noise (z) as the input of the generator network +noise_dim = 100 +# the number of filters in the layer in generator/discriminator that is +# closet to the image +gf_dim = 64 +df_dim = 64 +if dataSource == "mnist": + sample_dim = 28 # image dim + c_dim = 1 # image color +else: + sample_dim = 32 + c_dim = 3 +s2, s4 = int(sample_dim/2), int(sample_dim/4), +s8, s16 = int(sample_dim/8), int(sample_dim/16) + +settings( + batch_size=128, + learning_rate=2e-4, + learning_method=AdamOptimizer(beta1=0.5) +) + +def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name, + param_attr, bias_attr, param_attr_bn, bn, trans=False, + act=ReluActivation()): + + """ + conv_bn is a utility function that constructs a convolution/deconv layer + with an optional batch_norm layer + + :param bn: whether to use batch_norm_layer + :type bn: bool + :param trans: whether to use conv (False) or deconv (True) + :type trans: bool + """ + + # calculate the filter_size and padding size based on the given + # imgSize and ouput size + tmp = imgSize - (output_x - 1) * stride + if tmp <= 1 or tmp > 5: + raise ValueError("conv input-output dimension does not fit") + elif tmp <= 3: + filter_size = tmp + 2 + padding = 1 + else: + filter_size = tmp + padding = 0 + + print (imgSize, output_x, stride, filter_size, padding) + + if trans: + nameApx = "_conv" + else: + nameApx = "_convt" + + if bn: + conv = img_conv_layer(input, filter_size=filter_size, + num_filters=num_filters, + name=name + nameApx, num_channels=channels, + act=LinearActivation(), groups=1, stride=stride, + padding=padding, bias_attr=bias_attr, + param_attr=param_attr, shared_biases=True, layer_attr=None, + filter_size_y=None, stride_y=None, padding_y=None, + trans=trans) + + conv_bn = batch_norm_layer(conv, + act=act, + name=name + nameApx + "_bn", + bias_attr=bias_attr, + param_attr=param_attr_bn, + use_global_stats=False) + + return conv_bn + else: + conv = img_conv_layer(input, filter_size=filter_size, + num_filters=num_filters, + name=name + nameApx, num_channels=channels, + act=act, groups=1, stride=stride, + padding=padding, bias_attr=bias_attr, + param_attr=param_attr, shared_biases=True, layer_attr=None, + filter_size_y=None, stride_y=None, padding_y=None, + trans=trans) + return conv + +def generator(noise): + """ + generator generates a sample given noise + """ + param_attr = ParamAttr(is_static=is_discriminator_training, + initial_mean=0.0, + initial_std=0.02) + bias_attr = ParamAttr(is_static=is_discriminator_training, + initial_mean=0.0, + initial_std=0.0) + + param_attr_bn=ParamAttr(is_static=is_discriminator_training, + initial_mean=1.0, + initial_std=0.02) + + h1 = fc_layer(input=noise, + name="gen_layer_h1", + size=s8 * s8 * gf_dim * 4, + bias_attr=bias_attr, + param_attr=param_attr, + act=LinearActivation()) + + h1_bn = batch_norm_layer(h1, + act=ReluActivation(), + name="gen_layer_h1_bn", + bias_attr=bias_attr, + param_attr=param_attr_bn, + use_global_stats=False) + + h2_bn = conv_bn(h1_bn, + channels=gf_dim*4, + output_x=s8, + num_filters=gf_dim*2, + imgSize=s4, + stride=2, + name="gen_layer_h2", + param_attr=param_attr, + bias_attr=bias_attr, + param_attr_bn=param_attr_bn, + bn=True, + trans=True) + + h3_bn = conv_bn(h2_bn, + channels=gf_dim*2, + output_x=s4, + num_filters=gf_dim, + imgSize=s2, + stride=2, + name="gen_layer_h3", + param_attr=param_attr, + bias_attr=bias_attr, + param_attr_bn=param_attr_bn, + bn=True, + trans=True) + + + return conv_bn(h3_bn, + channels=gf_dim, + output_x=s2, + num_filters=c_dim, + imgSize=sample_dim, + stride=2, + name="gen_layer_h4", + param_attr=param_attr, + bias_attr=bias_attr, + param_attr_bn=param_attr_bn, + bn=False, + trans=True, + act=TanhActivation()) + + +def discriminator(sample): + """ + discriminator ouputs the probablity of a sample is from generator + or real data. + The output has two dimenstional: dimension 0 is the probablity + of the sample is from generator and dimension 1 is the probabblity + of the sample is from real data. + """ + param_attr = ParamAttr(is_static=is_generator_training, + initial_mean=0.0, + initial_std=0.02) + bias_attr = ParamAttr(is_static=is_generator_training, + initial_mean=0.0, + initial_std=0.0) + + param_attr_bn=ParamAttr(is_static=is_generator_training, + initial_mean=1.0, + initial_std=0.02) + + h0 = conv_bn(sample, + channels=c_dim, + imgSize=sample_dim, + num_filters=df_dim, + output_x=s2, + stride=2, + name="dis_h0", + param_attr=param_attr, + bias_attr=bias_attr, + param_attr_bn=param_attr_bn, + bn=False) + + h1_bn = conv_bn(h0, + channels=df_dim, + imgSize=s2, + num_filters=df_dim*2, + output_x=s4, + stride=2, + name="dis_h1", + param_attr=param_attr, + bias_attr=bias_attr, + param_attr_bn=param_attr_bn, + bn=True) + + h2_bn = conv_bn(h1_bn, + channels=df_dim*2, + imgSize=s4, + num_filters=df_dim*4, + output_x=s8, + stride=2, + name="dis_h2", + param_attr=param_attr, + bias_attr=bias_attr, + param_attr_bn=param_attr_bn, + bn=True) + + return fc_layer(input=h2_bn, name="dis_prob", size=2, + bias_attr=bias_attr, + param_attr=param_attr, + act=SoftmaxActivation()) + + + +if is_generator_training: + noise = data_layer(name="noise", size=noise_dim) + sample = generator(noise) + +if is_discriminator_training: + sample = data_layer(name="sample", size=sample_dim * sample_dim*c_dim) + +if is_generator_training or is_discriminator_training: + label = data_layer(name="label", size=1) + prob = discriminator(sample) + cost = cross_entropy(input=prob, label=label) + classification_error_evaluator(input=prob, label=label, name=mode+'_error') + outputs(cost) + +if is_generator: + noise = data_layer(name="noise", size=noise_dim) + outputs(generator(noise)) diff --git a/demo/gan/gan_trainer.py b/demo/gan/gan_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..72699952b961cb5bf6ac14dd65eee1aeab5e2a7c --- /dev/null +++ b/demo/gan/gan_trainer.py @@ -0,0 +1,329 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import random +import numpy +import cPickle +import sys,os +from PIL import Image + +from paddle.trainer.config_parser import parse_config +from paddle.trainer.config_parser import logger +import py_paddle.swig_paddle as api +import matplotlib.pyplot as plt + +def plot2DScatter(data, outputfile): + ''' + Plot the data as a 2D scatter plot and save to outputfile + data needs to be two dimensinoal + ''' + x = data[:, 0] + y = data[:, 1] + logger.info("The mean vector is %s" % numpy.mean(data, 0)) + logger.info("The std vector is %s" % numpy.std(data, 0)) + + heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50) + extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] + + plt.clf() + plt.scatter(x, y) + plt.savefig(outputfile, bbox_inches='tight') + +def CHECK_EQ(a, b): + assert a == b, "a=%s, b=%s" % (a, b) + +def copy_shared_parameters(src, dst): + ''' + copy the parameters from src to dst + :param src: the source of the parameters + :type src: GradientMachine + :param dst: the destination of the parameters + :type dst: GradientMachine + ''' + src_params = [src.getParameter(i) + for i in xrange(src.getParameterSize())] + src_params = dict([(p.getName(), p) for p in src_params]) + + + for i in xrange(dst.getParameterSize()): + dst_param = dst.getParameter(i) + src_param = src_params.get(dst_param.getName(), None) + if src_param is None: + continue + src_value = src_param.getBuf(api.PARAMETER_VALUE) + dst_value = dst_param.getBuf(api.PARAMETER_VALUE) + CHECK_EQ(len(src_value), len(dst_value)) + dst_value.copyFrom(src_value) + dst_param.setValueUpdated() + +def print_parameters(src): + src_params = [src.getParameter(i) + for i in xrange(src.getParameterSize())] + + print "***************" + for p in src_params: + print "Name is %s" % p.getName() + print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray() + +def load_mnist_data(imageFile): + f = open(imageFile, "rb") + f.read(16) + + # Define number of samples for train/test + if "train" in imageFile: + n = 60000 + else: + n = 10000 + + data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28)) + data = data / 255.0 * 2.0 - 1.0 + + f.close() + return data.astype('float32') + +def load_cifar_data(cifar_path): + batch_size = 10000 + data = numpy.zeros((5*batch_size, 32*32*3), dtype = "float32") + for i in range(1, 6): + file = cifar_path + "/data_batch_" + str(i) + fo = open(file, 'rb') + dict = cPickle.load(fo) + fo.close() + data[(i - 1)*batch_size:(i*batch_size), :] = dict["data"] + + data = data / 255.0 * 2.0 - 1.0 + return data + +# synthesize 2-D uniform data +def load_uniform_data(): + data = numpy.random.rand(1000000, 2).astype('float32') + return data + +def merge(images, size): + if images.shape[1] == 28*28: + h, w, c = 28, 28, 1 + else: + h, w, c = 32, 32, 3 + img = numpy.zeros((h * size[0], w * size[1], c)) + for idx in xrange(size[0] * size[1]): + i = idx % size[1] + j = idx // size[1] + img[j*h:j*h+h, i*w:i*w+w, :] = \ + ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0) + return img.astype('uint8') + +def save_images(images, path): + merged_img = merge(images, [8, 8]) + if merged_img.shape[2] == 1: + im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB') + else: + im = Image.fromarray(merged_img, mode="RGB") + im.save(path) + +def get_real_samples(batch_size, data_np): + return data_np[numpy.random.choice(data_np.shape[0], batch_size, + replace=False),:] + +def get_noise(batch_size, noise_dim): + return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32') + +def get_fake_samples(generator_machine, batch_size, noise): + gen_inputs = api.Arguments.createArguments(1) + gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) + gen_outputs = api.Arguments.createArguments(0) + generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST) + fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat() + return fake_samples + +def get_training_loss(training_machine, inputs): + outputs = api.Arguments.createArguments(0) + training_machine.forward(inputs, outputs, api.PASS_TEST) + loss = outputs.getSlotValue(0).copyToNumpyMat() + return numpy.mean(loss) + +def prepare_discriminator_data_batch_pos(batch_size, data_np): + real_samples = get_real_samples(batch_size, data_np) + labels = numpy.ones(batch_size, dtype='int32') + inputs = api.Arguments.createArguments(2) + inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_samples)) + inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) + return inputs + +def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise): + fake_samples = get_fake_samples(generator_machine, batch_size, noise) + labels = numpy.zeros(batch_size, dtype='int32') + inputs = api.Arguments.createArguments(2) + inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(fake_samples)) + inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) + return inputs + +def prepare_generator_data_batch(batch_size, noise): + label = numpy.ones(batch_size, dtype='int32') + inputs = api.Arguments.createArguments(2) + inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) + inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label)) + return inputs + + +def find(iterable, cond): + for item in iterable: + if cond(item): + return item + return None + + +def get_layer_size(model_conf, layer_name): + layer_conf = find(model_conf.layers, lambda x: x.name == layer_name) + assert layer_conf is not None, "Cannot find '%s' layer" % layer_name + return layer_conf.size + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--data_source", help="mnist or cifar or uniform") + parser.add_argument("--use_gpu", default="1", + help="1 means use gpu for training") + parser.add_argument("--gpu_id", default="0", + help="the gpu_id parameter") + args = parser.parse_args() + data_source = args.data_source + use_gpu = args.use_gpu + assert data_source in ["mnist", "cifar", "uniform"] + assert use_gpu in ["0", "1"] + + if not os.path.exists("./%s_samples/" % data_source): + os.makedirs("./%s_samples/" % data_source) + + if not os.path.exists("./%s_params/" % data_source): + os.makedirs("./%s_params/" % data_source) + + api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10', '--log_period=100', + '--gpu_id=' + args.gpu_id, '--save_dir=' + "./%s_params/" % data_source) + + if data_source == "uniform": + conf = "gan_conf.py" + num_iter = 10000 + else: + conf = "gan_conf_image.py" + num_iter = 1000 + + gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source) + dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source) + generator_conf = parse_config(conf, "mode=generator,data=" + data_source) + batch_size = dis_conf.opt_config.batch_size + noise_dim = get_layer_size(gen_conf.model_config, "noise") + + if data_source == "mnist": + data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte") + elif data_source == "cifar": + data_np = load_cifar_data("./data/cifar-10-batches-py/") + else: + data_np = load_uniform_data() + + # this creates a gradient machine for discriminator + dis_training_machine = api.GradientMachine.createFromConfigProto( + dis_conf.model_config) + # this create a gradient machine for generator + gen_training_machine = api.GradientMachine.createFromConfigProto( + gen_conf.model_config) + + # generator_machine is used to generate data only, which is used for + # training discriminator + logger.info(str(generator_conf.model_config)) + generator_machine = api.GradientMachine.createFromConfigProto( + generator_conf.model_config) + + dis_trainer = api.Trainer.create( + dis_conf, dis_training_machine) + + gen_trainer = api.Trainer.create( + gen_conf, gen_training_machine) + + dis_trainer.startTrain() + gen_trainer.startTrain() + + # Sync parameters between networks (GradientMachine) at the beginning + copy_shared_parameters(gen_training_machine, dis_training_machine) + copy_shared_parameters(gen_training_machine, generator_machine) + + # constrain that either discriminator or generator can not be trained + # consecutively more than MAX_strike times + curr_train = "dis" + curr_strike = 0 + MAX_strike = 5 + + for train_pass in xrange(100): + dis_trainer.startTrainPass() + gen_trainer.startTrainPass() + for i in xrange(num_iter): + # Do forward pass in discriminator to get the dis_loss + noise = get_noise(batch_size, noise_dim) + data_batch_dis_pos = prepare_discriminator_data_batch_pos( + batch_size, data_np) + dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos) + + data_batch_dis_neg = prepare_discriminator_data_batch_neg( + generator_machine, batch_size, noise) + dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg) + + dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0 + + # Do forward pass in generator to get the gen_loss + data_batch_gen = prepare_generator_data_batch( + batch_size, noise) + gen_loss = get_training_loss(gen_training_machine, data_batch_gen) + + if i % 100 == 0: + print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg) + print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss) + + # Decide which network to train based on the training history + # And the relative size of the loss + if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \ + ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss): + if curr_train == "dis": + curr_strike += 1 + else: + curr_train = "dis" + curr_strike = 1 + dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg) + dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) + copy_shared_parameters(dis_training_machine, gen_training_machine) + + else: + if curr_train == "gen": + curr_strike += 1 + else: + curr_train = "gen" + curr_strike = 1 + gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) + # TODO: add API for paddle to allow true parameter sharing between different GradientMachines + # so that we do not need to copy shared parameters. + copy_shared_parameters(gen_training_machine, dis_training_machine) + copy_shared_parameters(gen_training_machine, generator_machine) + + dis_trainer.finishTrainPass() + gen_trainer.finishTrainPass() + # At the end of each pass, save the generated samples/images + fake_samples = get_fake_samples(generator_machine, batch_size, noise) + if data_source == "uniform": + plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass)) + else: + save_images(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass)) + dis_trainer.finishTrain() + gen_trainer.finishTrain() + +if __name__ == '__main__': + main() diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index b539374cd4aa5a9510cdb728c1b22edf65a9f880..bd1fdffe8984e8b8804c576890ec6a37dc7cf574 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -27,11 +27,6 @@ Arguments* Arguments::createArguments(size_t slotNum) { void Arguments::resize(size_t slotNum) { m->outputs.resize(slotNum); } -Matrix* Arguments::getSlotValue(size_t idx) const throw(RangeError) { - auto& a = m->getArg(idx); - return Matrix::createByPaddleMatrixPtr(&a.value); -} - Arguments::Arguments() : m(new ArgumentsPrivate()) {} Arguments::~Arguments() { delete m; } @@ -43,6 +38,16 @@ Arguments* Arguments::createByPaddleArgumentVector(void* ptr) { return args; } +Matrix* Arguments::getSlotValue(size_t idx) const throw(RangeError) { + auto& a = m->getArg(idx); + return Matrix::createByPaddleMatrixPtr(&a.value); +} + +Matrix* Arguments::getSlotGrad(size_t idx) const throw(RangeError) { + auto& a = m->getArg(idx); + return Matrix::createByPaddleMatrixPtr(&a.grad); +} + IVector* Arguments::getSlotIds(size_t idx) const throw(RangeError) { auto& a = m->getArg(idx); return IVector::createByPaddleVectorPtr(&a.ids); @@ -58,6 +63,11 @@ void Arguments::setSlotValue(size_t idx, Matrix* mat) throw(RangeError) { a.value = m->cast(mat->getSharedPtr()); } +void Arguments::setSlotGrad(size_t idx, Matrix* mat) throw(RangeError) { + auto& a = m->getArg(idx); + a.grad = m->cast(mat->getSharedPtr()); +} + void Arguments::setSlotIn(size_t idx, Matrix* mat) throw(RangeError) { auto& a = m->getArg(idx); a.in = m->cast(mat->getSharedPtr()); diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 6a0fbc537d9345f2221ab65d90733f4696be6880..9194a6371be9e00c037967464ee2b63c1e4f6192 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -193,5 +193,4 @@ namespace std { %ignore OptimizationConfigPrivate; %ignore ParameterTraverseCallbackPrivate; %include "utils/GlobalConstants.h" -%include "api/PaddleAPI.h" - +%include "api/PaddleAPI.h" \ No newline at end of file diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index c07facdb1292b34ac31247160a4347ea359e718b..a125934fc17ceb2df3b4fd89538e7a79eee3761e 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -156,12 +156,15 @@ public: * @param dim1 dimension of data. * @param dim2 dimension of data. * @param copy true if copy into a new matrix, false will create - * matrix inplace. + * matrix inplace. copy = false should be used with extreme + * care because Matrix will share the memory with the given + * numpy array. If the numpy array object is no longer valid, + * the memory space will not be usable. */ static Matrix* createCpuDenseFromNumpy(float* data, int dim1, int dim2, - bool copy = false); + bool copy = true); /// Create Gpu Dense Matrix from numpy matrix, dtype=float32 static Matrix* createGpuDenseFromNumpy(float* data, int dim1, int dim2); @@ -271,11 +274,18 @@ public: */ static Vector* createCpuVectorFromNumpy(float* data, int dim, - bool copy = false); + bool copy = true); /// Create Gpu Vector from numpy array, which dtype=float32 static Vector* createGpuVectorFromNumpy(float* data, int dim); + /** + * copy from another vector + * throw(RangeError) if size of src vector is different from size of this + * vector + */ + void copyFrom(Vector* src) throw(RangeError); + /// Cast to numpy array inplace. void toNumpyArrayInplace(float** view_data, int* dim1) throw(UnsupportError); @@ -339,7 +349,7 @@ public: */ static IVector* createCpuVectorFromNumpy(int* data, int dim, - bool copy = false); + bool copy = true); /** * Create Gpu IVector from numpy array, which dtype=int32 */ @@ -418,6 +428,7 @@ public: * the param idx is the slot id */ Matrix* getSlotValue(size_t idx) const throw(RangeError); + Matrix* getSlotGrad(size_t idx) const throw(RangeError); IVector* getSlotIds(size_t idx) const throw(RangeError); Matrix* getSlotIn(size_t idx) const throw(RangeError); IVector* getSlotSequenceStartPositions(size_t idx) const throw(RangeError); @@ -434,6 +445,7 @@ public: * The other param is the input Matrix or vector. */ void setSlotValue(size_t idx, Matrix* mat) throw(RangeError); + void setSlotGrad(size_t idx, Matrix* mat) throw(RangeError); void setSlotIn(size_t idx, Matrix* mat) throw(RangeError); void setSlotIds(size_t idx, IVector* vec) throw(RangeError); void setSlotSequenceStartPositions(size_t idx, @@ -535,6 +547,7 @@ public: size_t getID() const; ParameterConfig* getConfig(); + void setValueUpdated(); private: static Parameter* createFromRawPtr(void* ptr); diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index c5876bb1c71438578831ffffd85840c706b6224c..9c30ef6ff421235e84896813c701da5d8bfe7af9 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -68,3 +68,5 @@ ParameterConfig* Parameter::getConfig() { } size_t Parameter::getID() const { return m->getPtr()->getID(); } + +void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } diff --git a/paddle/api/Vector.cpp b/paddle/api/Vector.cpp index cc1c098223826a06fea291a95730d7fc1fd1beb3..74c9ff8dc7373f2beb6e6faaf951678038803c56 100644 --- a/paddle/api/Vector.cpp +++ b/paddle/api/Vector.cpp @@ -281,6 +281,13 @@ FloatArray Vector::getData() const { } } +void Vector::copyFrom(Vector* src) throw(RangeError) { + if (src->m->vec->getSize() != m->vec->getSize()) { + throw RangeError(); + } + m->vec->copyFrom(*src->m->vec); +} + bool Vector::isGpu() const { return std::dynamic_pointer_cast(m->vec) != nullptr; } diff --git a/paddle/api/test/testMatrix.py b/paddle/api/test/testMatrix.py index 0432345edd659f13bddb1b99f62622c5ea64a4cb..8b0da626928e292c392142a1c25c6bd8f677372b 100644 --- a/paddle/api/test/testMatrix.py +++ b/paddle/api/test/testMatrix.py @@ -68,7 +68,7 @@ class TestMatrix(unittest.TestCase): def test_numpyCpu(self): numpy_mat = np.matrix([[1, 2], [3, 4], [5, 6]], dtype="float32") - m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat) + m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat, copy=False) self.assertEqual((int(m.getHeight()), int(m.getWidth())), numpy_mat.shape) diff --git a/paddle/api/test/testVector.py b/paddle/api/test/testVector.py index 48aaa1d73da9e6c207ad5fa2be14a531267bd901..963359236d5e27ac569c00fd82b9a58f44eee4c9 100644 --- a/paddle/api/test/testVector.py +++ b/paddle/api/test/testVector.py @@ -43,7 +43,7 @@ class TestIVector(unittest.TestCase): def test_cpu_numpy(self): vec = np.array([1, 3, 4, 65, 78, 1, 4], dtype="int32") - iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec) + iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, copy=False) self.assertEqual(vec.shape[0], int(iv.__len__())) vec[4] = 832 for i in xrange(len(iv)): @@ -107,7 +107,7 @@ class TestVector(unittest.TestCase): def testCpuNumpy(self): numpy_arr = np.array([1.2, 2.3, 3.4, 4.5], dtype="float32") - vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr) + vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, copy=False) assert isinstance(vec, swig_paddle.Vector) numpy_arr[0] = 0.1 for n, v in zip(numpy_arr, vec): @@ -152,4 +152,4 @@ if __name__ == '__main__': unittest.TextTestRunner().run(suite) if swig_paddle.isGpuVersion(): swig_paddle.setUseGpu(True) - unittest.main() \ No newline at end of file + unittest.main() diff --git a/paddle/api/test/util.py b/paddle/api/test/util.py index 93a01b242f9f9a4c939cfbf9c4c7c47bb0e4e9cf..dbcdba5bf27c2fd7df95f8838ad5fdcd131cccf1 100644 --- a/paddle/api/test/util.py +++ b/paddle/api/test/util.py @@ -24,7 +24,9 @@ def doubleEqual(a, b): def __readFromFile(): for i in xrange(10002): - yield np.random.rand(784), random.randint(0, 9) + label = np.random.randint(0, 9) + sample = np.random.rand(784) + 0.1 * label + yield sample, label def loadMNISTTrainData(batch_size=100): diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp index 2d5bcff29fd5ad33c8eba85fc803bbf89803782e..6381f20a63c6b4ca24245cd6f30e4defda279de6 100644 --- a/paddle/gserver/layers/BatchNormBaseLayer.cpp +++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp @@ -68,10 +68,10 @@ void BatchNormBaseLayer::calFeatureMapSize() { } else { imageH_ = inputLayers_[0]->getOutput().getFrameHeight(); imageW_ = inputLayers_[0]->getOutput().getFrameWidth(); + getOutput().setFrameHeight(imageH_); + getOutput().setFrameWidth(imageW_); } imgPixels_ = imageH_ * imageW_; - getOutput().setFrameHeight(imageH_); - getOutput().setFrameWidth(imageW_); } } // namespace paddle diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 79741bef2fe19606b0cb989c9a1d11e737cc6063..9d427467e784a4c492182153dc88001b26791687 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -39,9 +39,17 @@ add_unittest_without_exec(test_ConvUnify test_ConvUnify.cpp LayerGradUtil.cpp TestUtil.cpp) - + add_test(NAME test_ConvUnify COMMAND test_ConvUnify) +################# test_BatchNorm ####################### +add_unittest_without_exec(test_BatchNorm + test_BatchNorm.cpp + LayerGradUtil.cpp + TestUtil.cpp) + +add_test(NAME test_BatchNorm + COMMAND test_BatchNorm) ################## test_Evaluator ####################### add_unittest(test_Evaluator test_Evaluator.cpp diff --git a/paddle/gserver/tests/test_BatchNorm.cpp b/paddle/gserver/tests/test_BatchNorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0cb6f58dc000bd0fb408e6f3a3aa4ff4240adf26 --- /dev/null +++ b/paddle/gserver/tests/test_BatchNorm.cpp @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/gserver/layers/DataLayer.h" +#include "ModelConfig.pb.h" +#include "paddle/trainer/Trainer.h" +#include "paddle/utils/GlobalConstants.h" +#include "paddle/gserver/layers/ExpandConvTransLayer.h" + +#include "TestUtil.h" +#include "LayerGradUtil.h" + +using namespace paddle; // NOLINT +using namespace std; // NOLINT + +P_DECLARE_bool(use_gpu); +P_DECLARE_int32(gpu_id); +P_DECLARE_double(checkgrad_eps); +P_DECLARE_bool(thread_local_rand_use_global_seed); +P_DECLARE_bool(prev_batch_state); + +// Test that the batchNormLayer can be followed by a ConvLayer +TEST(Layer, batchNorm) { + FLAGS_use_gpu = false; + TestConfig configBN; + const int CHANNELS = 6272; + const int IMG_SIZE = 1; + configBN.layerConfig.set_type("batch_norm"); + configBN.layerConfig.set_name("bn"); + configBN.layerConfig.set_size(CHANNELS * IMG_SIZE * IMG_SIZE); + configBN.layerConfig.set_active_type("relu"); + configBN.biasSize = CHANNELS; + configBN.inputDefs.push_back({INPUT_DATA, "layer_0", + /* dim= */ IMG_SIZE * IMG_SIZE * CHANNELS, + /* paraSize= */ CHANNELS}); + + configBN.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", + 1, CHANNELS}); + configBN.inputDefs.back().isStatic = true; + configBN.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", + 1, CHANNELS}); + configBN.inputDefs.back().isStatic = true; + + LayerInputConfig* input = configBN.layerConfig.add_inputs(); + configBN.layerConfig.add_inputs(); + configBN.layerConfig.add_inputs(); + + ImageConfig* img_conf = input->mutable_image_conf(); + img_conf->set_channels(CHANNELS); + img_conf->set_img_size(IMG_SIZE); + + // Setting up conv-layer config + TestConfig config; + config.biasSize = 64; + config.layerConfig.set_type("exconv"); + config.layerConfig.set_num_filters(64); + config.layerConfig.set_partial_sum(1); + config.layerConfig.set_shared_biases(true); + + config.inputDefs.push_back({INPUT_DATA, "bn", 6272, 204800}); + input = config.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + conv->set_filter_size(5); + conv->set_filter_size_y(5); + conv->set_channels(128); + conv->set_padding(1); + conv->set_padding_y(1); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_groups(1); + conv->set_filter_channels(conv->channels() / conv->groups()); + conv->set_img_size(7); + conv->set_output_x(3); + config.layerConfig.set_size(conv->output_x() * conv->output_x() * + config.layerConfig.num_filters()); + config.layerConfig.set_name("conv"); + + // data layer initialize + std::vector dataLayers; + LayerMap layerMap; + vector datas; + initDataLayer(configBN, &dataLayers, &datas, &layerMap, "batch_norm", + 100, false, false); + // test layer initialize + std::vector parameters; + LayerPtr bnLayer; + initTestLayer(configBN, &layerMap, ¶meters, &bnLayer); + + std::vector parameters2; + LayerPtr convLayer; + initTestLayer(config, &layerMap, ¶meters2, &convLayer); + + bnLayer->forward(PASS_GC); + convLayer->forward(PASS_GC); + + CHECK_EQ(convLayer->getOutputValue()->getHeight(), 100); + CHECK_EQ(convLayer->getOutputValue()->getWidth(), 576); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + initMain(argc, argv); + FLAGS_thread_local_rand_use_global_seed = true; + srand(1); + return RUN_ALL_TESTS(); +} diff --git a/paddle/py_paddle/util.py b/paddle/py_paddle/util.py index e1f310580f95cfb210ba89589bab668433818b23..35a355ef29cebd84fd34e00cee05218220b2eb43 100644 --- a/paddle/py_paddle/util.py +++ b/paddle/py_paddle/util.py @@ -559,10 +559,10 @@ def __monkey_patch_trainer__(): def monkeypatches(): - patches = [ - __monkeypatch_init_paddle__, __monkeypatch_gradient_machine__, - __monkey_patch_protobuf_objects__, __monkey_patch_parameter__, - __monkey_patch_trainer__ - ] + patches = [__monkeypatch_init_paddle__, + __monkeypatch_gradient_machine__, + __monkey_patch_protobuf_objects__, + __monkey_patch_parameter__, + __monkey_patch_trainer__] for patch in patches: patch()