diff --git a/gan/README.md b/gan/README.md
index d059941b03d0b653d538156056cfb6095c12c05d..5bac5767c8965a68f9fc7c2f6cb81eba281c94e3 100644
--- a/gan/README.md
+++ b/gan/README.md
@@ -1,18 +1,18 @@
 # Generative Adversarial Networks
 ## Background
-This chapter introduces the Generative Adversarial Network (GAN) \[[1](#参考文献)\]. A GAN is a kind of generative model that learns the distribution of its input data through unsupervised learning, so that it can produce synthetic data following the same probability distribution as the input data.
+This chapter introduces the Generative Adversarial Network (GAN) \[[1](#参考文献)\]. A GAN is a kind of generative model that learns the distribution of its input data through unsupervised learning, so that it can produce synthetic data following the same probability distribution as the input data. This ability enables machines to perform tasks such as automatic image generation, image denoising, missing-image completion, and image super-resolution.
 
-Most of today's successful applications of deep learning are supervised: they map high-dimensional data to a low-dimensional representation in order to classify it (see the earlier chapters). This approach is also called a discriminative model. However, a representation learned this way usually works well only for the target task and does not transfer well to other tasks. Moreover, supervised training requires large amounts of labeled data, which is often hard to obtain.
+Most of today's successful applications of deep learning are supervised: they map high-dimensional data to a low-dimensional representation in order to classify it (see the earlier chapters). This approach is also called a discriminative model; it directly models the conditional probability P(y|x). All eight of the preceding chapters use discriminative models. However, a representation learned this way usually works well only for the target task and does not transfer well to other tasks. Moreover, supervised training requires large amounts of labeled data, which is often hard to obtain.
 
-To learn general, effective representations from large amounts of unlabeled data, another family of models was invented: generative models. The basic idea behind them is that if a model can generate data very close to the real data, it has probably learned an effective representation of that data. Other practical uses of generative models include image denoising, missing-image completion, and image super-resolution. When labeled data is scarce, data produced by a generative model can also be used to pretrain a model.
+The basic idea behind generative models is that if a model can generate data very close to the real data, it has probably learned an effective representation of that data. Other practical uses of generative models include image denoising, missing-image completion, and image super-resolution. When labeled data is scarce, data produced by a generative model can also be used to pretrain a model.
 
-Commonly used generative models fall roughly into two types. One is the variational autoencoder \[[3](#参考文献)\], which builds a generative model in the framework of probabilistic graphical models and gives a complete probabilistic description of the data; training maximizes the likelihood of the data by adjusting the parameters. Images generated this way tend to have high likelihood but often look blurry. To address this problem, another generative model was proposed: the generative adversarial network introduced in this chapter.
+Several interesting image generation models have appeared in recent years. One is the variational autoencoder \[[3](#参考文献)\], which builds a generative model in the framework of probabilistic graphical models and gives a complete probabilistic description of the data (that is, it models P(x)); training maximizes the likelihood of the data by adjusting the parameters. Images generated this way tend to have high likelihood but often look blurry. To address this problem, another generative model was proposed: the generative adversarial network introduced in this chapter.
 
 In this chapter, we present the details of generative adversarial networks and show how to train a GAN model with PaddlePaddle.
 
 ## Demonstration
-A simple example is to feed MNIST handwritten-digit images to a GAN and let the model produce similar handwritten-digit images on its own. Figure 1 shows examples of digit images produced by a trained GAN model.
+A simple example is to train a GAN so that it learns to produce MNIST handwritten-digit images. Figure 1 shows examples of digit images produced by a trained GAN model.


@@ -20,11 +20,11 @@

 ## Model Overview
-The overall structure of a GAN is shown in Figure 2. It consists of two parts: a generator G and a discriminator D, both of which are multi-layer neural networks. The generator takes as input a multi-dimensional noise vector z drawn from a known probability distribution, transforms it through the network, and outputs a fake sample. The discriminator takes both real and fake samples as input and outputs the probability that its input is a real sample versus a fake one. During training, the generator and the discriminator compete with each other: the generator tries to produce fake samples so close to the real ones that the discriminator cannot tell them apart, while the discriminator does its best to identify the fake samples. The loss function is:
+The overall structure of a GAN is shown in Figure 2. It consists of two parts: a generator G and a discriminator D, both of which are multi-layer neural networks. The generator takes as input a multi-dimensional noise vector z drawn from a known probability distribution (the noise distribution does not depend on the samples to be generated; it can, for example, be a normal distribution), transforms it through the network, and outputs a fake sample. The discriminator takes both real and fake samples as input and outputs the probability that its input is a real sample versus a fake one. During training, the generator and the discriminator compete with each other: the generator tries to produce fake samples so close to the real ones that the discriminator cannot tell them apart, while the discriminator does its best to identify the fake samples. The loss function is:
 
 $$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m\left[\log D(x^i) + \log\left(1-D(G(z^i))\right)\right]$$
 
-In this loss function, $x$ is real data and $z$ is noise from a known distribution, so the loss represents the probability of real data being classified as real plus the probability of fake data being classified as fake. The discriminator D aims to increase this value, hence the max in the formula; the generator G aims to decrease it, hence the min.
+Here $x$ is real data and $z$ is noise from a known distribution, so the loss represents the probability of real data being classified as real plus the probability of fake data being classified as fake. The discriminator D aims to increase this value, hence the max in the formula; the generator G aims to decrease it, hence the min.
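+
+To make the objective concrete, here is a short numpy sketch (illustrative only, not part of this chapter's training code; `d_real` and `d_fake` are hypothetical arrays holding the discriminator outputs $D(x^i)$ and $D(G(z^i))$) that evaluates the bracketed quantity for one mini-batch:
+
+```python
+import numpy
+
+def gan_minimax_loss(d_real, d_fake, eps=1e-8):
+    # d_real[i] = D(x^i): probability that the i-th real sample is judged real
+    # d_fake[i] = D(G(z^i)): probability that the i-th fake sample is judged real
+    # D ascends this quantity (max); G descends it (min).
+    return numpy.mean(numpy.log(d_real + eps) + numpy.log(1.0 - d_fake + eps))
+
+# A discriminator that is confident on both kinds of samples scores near the
+# maximum of 0; the generator tries to drive the score down.
+d_real = numpy.array([0.9, 0.8, 0.95])
+d_fake = numpy.array([0.1, 0.2, 0.05])
+print(gan_minimax_loss(d_real, d_fake))  # about -0.25
+```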


@@ -64,6 +64,13 @@
 $cd data/
 $./get_mnist_data.sh
 ```
+Another, more realistic image dataset is Cifar-10, which can be downloaded with the commands below:
+
+```bash
+$cd data/
+$./download_cifar.sh
+```
+
 ## Model Configuration
 Because a generative adversarial network involves multiple neural networks, it has to be trained through the PaddlePaddle Python API. The walkthrough below can also serve, in part, as a usage guide for the PaddlePaddle Python API.
diff --git a/gan/data/download_cifar.sh b/gan/data/download_cifar.sh
new file mode 100755
index 0000000000000000000000000000000000000000..bbadc7c10c73e45a0948018b8812f79040d14bc4
--- /dev/null
+++ b/gan/data/download_cifar.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
+tar zxf cifar-10-python.tar.gz
+rm cifar-10-python.tar.gz
diff --git a/gan/data/get_mnist_data.sh b/gan/data/get_mnist_data.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a77c81bf5af9ddb6634ff89460797ca543c5e517
--- /dev/null
+++ b/gan/data/get_mnist_data.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env sh
+# This script downloads the mnist data and unzips it.
+set -e
+DIR="$( cd "$(dirname "$0")" ; pwd -P )"
+rm -rf "$DIR/mnist_data"
+mkdir "$DIR/mnist_data"
+cd "$DIR/mnist_data"
+
+echo "Downloading..."
+
+for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
+do
+    if [ ! -e $fname ]; then
+        wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz
+        gunzip ${fname}.gz
+    fi
+done
diff --git a/gan/gan_conf.py b/gan/gan_conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ac2dffe5f4490a88e12d1fa5e8cd9fa61a69f4
--- /dev/null
+++ b/gan/gan_conf.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from paddle.trainer_config_helpers import *
+
+mode = get_config_arg("mode", str, "generator")
+assert mode in set([
+    "generator", "discriminator", "generator_training", "discriminator_training"
+])
+
+is_generator_training = mode == "generator_training"
+is_discriminator_training = mode == "discriminator_training"
+is_generator = mode == "generator"
+is_discriminator = mode == "discriminator"
+
+# The network structure below follows the ref https://arxiv.org/abs/1406.2661
+# Here we used two hidden layers and batch_norm
+
+print('mode=%s' % mode)
+# the dim of the noise (z) as the input of the generator network
+noise_dim = 10
+# the dim of the hidden layer
+hidden_dim = 10
+# the dim of the generated sample
+sample_dim = 2
+
+settings(
+    batch_size=128,
+    learning_rate=1e-4,
+    learning_method=AdamOptimizer(beta1=0.5))
+
+
+def discriminator(sample):
+    """
+    The discriminator outputs the probability that a sample comes from the
+    generator or from real data.
+    The output has two dimensions: dimension 0 is the probability that the
+    sample comes from the generator, and dimension 1 is the probability
+    that it comes from real data.
+    """
+    param_attr = ParamAttr(is_static=is_generator_training)
+    bias_attr = ParamAttr(
+        is_static=is_generator_training, initial_mean=1.0, initial_std=0)
+
+    hidden = fc_layer(
+        input=sample,
+        name="dis_hidden",
+        size=hidden_dim,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=ReluActivation())
+
+    hidden2 = fc_layer(
+        input=hidden,
+        name="dis_hidden2",
+        size=hidden_dim,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=LinearActivation())
+
+    hidden_bn = batch_norm_layer(
+        hidden2,
+        act=ReluActivation(),
+        name="dis_hidden_bn",
+        bias_attr=bias_attr,
+        param_attr=ParamAttr(
+            is_static=is_generator_training, initial_mean=1.0,
+            initial_std=0.02),
+        use_global_stats=False)
+
+    return fc_layer(
+        input=hidden_bn,
+        name="dis_prob",
+        size=2,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=SoftmaxActivation())
+
+
+def generator(noise):
+    """
+    The generator generates a sample given noise.
+    """
+    param_attr = ParamAttr(is_static=is_discriminator_training)
+    bias_attr = ParamAttr(
+        is_static=is_discriminator_training, initial_mean=1.0, initial_std=0)
+
+    hidden = fc_layer(
+        input=noise,
+        name="gen_layer_hidden",
+        size=hidden_dim,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=ReluActivation())
+
+    hidden2 = fc_layer(
+        input=hidden,
+        name="gen_hidden2",
+        size=hidden_dim,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=LinearActivation())
+
+    hidden_bn = batch_norm_layer(
+        hidden2,
+        act=ReluActivation(),
+        name="gen_layer_hidden_bn",
+        bias_attr=bias_attr,
+        param_attr=ParamAttr(
+            is_static=is_discriminator_training,
+            initial_mean=1.0,
+            initial_std=0.02),
+        use_global_stats=False)
+
+    return fc_layer(
+        input=hidden_bn,
+        name="gen_layer1",
+        size=sample_dim,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=LinearActivation())
+
+
+if is_generator_training:
+    noise = data_layer(name="noise", size=noise_dim)
+    sample = generator(noise)
+
+if is_discriminator_training:
+    sample = data_layer(name="sample", size=sample_dim)
+
+if is_generator_training or is_discriminator_training:
+    label = data_layer(name="label", size=1)
+    prob = discriminator(sample)
+    cost = cross_entropy(input=prob, label=label)
+    classification_error_evaluator(
+        input=prob, label=label, name=mode + '_error')
+    outputs(cost)
+
+if is_generator:
+    noise = data_layer(name="noise", size=noise_dim)
+    outputs(generator(noise))
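A quick note on how this configuration file is consumed: the same file is parsed several times with different `mode` arguments, yielding different networks whose shared layers carry the same parameter names. Below is a minimal sketch of that mechanism, using the same `parse_config` API that `gan_trainer.py` calls later (it assumes the PaddlePaddle Python packages are importable):

```python
from paddle.trainer.config_parser import parse_config

# One source file, several networks: the mode flag selects which layers are
# built and which parameters are frozen (is_static) during each phase.
gen_conf = parse_config("gan_conf.py", "mode=generator_training")
dis_conf = parse_config("gan_conf.py", "mode=discriminator_training")

# The parsed result is a protobuf; layer sizes can be read back from it,
# e.g. the noise layer declared with noise_dim = 10 above.
noise_layer = [l for l in gen_conf.model_config.layers if l.name == "noise"][0]
print(noise_layer.size)  # 10
```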
diff --git a/gan/gan_conf_image.py b/gan/gan_conf_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c469227994c1a84d1aa73e03bbc74ebeac41d30e
--- /dev/null
+++ b/gan/gan_conf_image.py
@@ -0,0 +1,298 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from paddle.trainer_config_helpers import *
+
+mode = get_config_arg("mode", str, "generator")
+dataSource = get_config_arg("data", str, "mnist")
+assert mode in set([
+    "generator", "discriminator", "generator_training", "discriminator_training"
+])
+
+is_generator_training = mode == "generator_training"
+is_discriminator_training = mode == "discriminator_training"
+is_generator = mode == "generator"
+is_discriminator = mode == "discriminator"
+
+# The network structure below follows the DCGAN paper
+# (https://arxiv.org/abs/1511.06434)
+
+print('mode=%s' % mode)
+# the dim of the noise (z) as the input of the generator network
+noise_dim = 100
+# the number of filters in the layer in generator/discriminator that is
+# closest to the image
+gf_dim = 64
+df_dim = 64
+if dataSource == "mnist":
+    sample_dim = 28  # image dim
+    c_dim = 1  # image color
+else:
+    sample_dim = 32
+    c_dim = 3
+s2, s4 = int(sample_dim / 2), int(sample_dim / 4)
+s8, s16 = int(sample_dim / 8), int(sample_dim / 16)
+
+settings(
+    batch_size=128,
+    learning_rate=2e-4,
+    learning_method=AdamOptimizer(beta1=0.5))
+
+
+def conv_bn(input,
+            channels,
+            imgSize,
+            num_filters,
+            output_x,
+            stride,
+            name,
+            param_attr,
+            bias_attr,
+            param_attr_bn,
+            bn,
+            trans=False,
+            act=ReluActivation()):
+    """
+    conv_bn is a utility function that constructs a convolution/deconv layer
+    with an optional batch_norm layer.
+
+    :param bn: whether to use batch_norm_layer
+    :type bn: bool
+    :param trans: whether to use conv (False) or deconv (True)
+    :type trans: bool
+    """
+
+    # calculate the filter_size and padding size based on the given
+    # imgSize and output size
+    tmp = imgSize - (output_x - 1) * stride
+    if tmp <= 1 or tmp > 5:
+        raise ValueError("conv input-output dimension does not fit")
+    elif tmp <= 3:
+        filter_size = tmp + 2
+        padding = 1
+    else:
+        filter_size = tmp
+        padding = 0
+
+    print(imgSize, output_x, stride, filter_size, padding)
+
+    if trans:
+        nameApx = "_convt"
+    else:
+        nameApx = "_conv"
+
+    if bn:
+        conv = img_conv_layer(
+            input,
+            filter_size=filter_size,
+            num_filters=num_filters,
+            name=name + nameApx,
+            num_channels=channels,
+            act=LinearActivation(),
+            groups=1,
+            stride=stride,
+            padding=padding,
+            bias_attr=bias_attr,
+            param_attr=param_attr,
+            shared_biases=True,
+            layer_attr=None,
+            filter_size_y=None,
+            stride_y=None,
+            padding_y=None,
+            trans=trans)
+
+        conv_bn = batch_norm_layer(
+            conv,
+            act=act,
+            name=name + nameApx + "_bn",
+            bias_attr=bias_attr,
+            param_attr=param_attr_bn,
+            use_global_stats=False)
+
+        return conv_bn
+    else:
+        conv = img_conv_layer(
+            input,
+            filter_size=filter_size,
+            num_filters=num_filters,
+            name=name + nameApx,
+            num_channels=channels,
+            act=act,
+            groups=1,
+            stride=stride,
+            padding=padding,
+            bias_attr=bias_attr,
+            param_attr=param_attr,
+            shared_biases=True,
+            layer_attr=None,
+            filter_size_y=None,
+            stride_y=None,
+            padding_y=None,
+            trans=trans)
+        return conv
+
+
+def generator(noise):
+    """
+    The generator generates a sample given noise.
+    """
+    param_attr = ParamAttr(
+        is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.02)
+    bias_attr = ParamAttr(
+        is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.0)
+
+    param_attr_bn = ParamAttr(
+        is_static=is_discriminator_training, initial_mean=1.0, initial_std=0.02)
+
+    h1 = fc_layer(
+        input=noise,
+        name="gen_layer_h1",
+        size=s8 * s8 * gf_dim * 4,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=LinearActivation())
+
+    h1_bn = batch_norm_layer(
+        h1,
+        act=ReluActivation(),
+        name="gen_layer_h1_bn",
+        bias_attr=bias_attr,
+        param_attr=param_attr_bn,
+        use_global_stats=False)
+
+    h2_bn = conv_bn(
+        h1_bn,
+        channels=gf_dim * 4,
+        output_x=s8,
+        num_filters=gf_dim * 2,
+        imgSize=s4,
+        stride=2,
+        name="gen_layer_h2",
+        param_attr=param_attr,
+        bias_attr=bias_attr,
+        param_attr_bn=param_attr_bn,
+        bn=True,
+        trans=True)
+
+    h3_bn = conv_bn(
+        h2_bn,
+        channels=gf_dim * 2,
+        output_x=s4,
+        num_filters=gf_dim,
+        imgSize=s2,
+        stride=2,
+        name="gen_layer_h3",
+        param_attr=param_attr,
+        bias_attr=bias_attr,
+        param_attr_bn=param_attr_bn,
+        bn=True,
+        trans=True)
+
+    return conv_bn(
+        h3_bn,
+        channels=gf_dim,
+        output_x=s2,
+        num_filters=c_dim,
+        imgSize=sample_dim,
+        stride=2,
+        name="gen_layer_h4",
+        param_attr=param_attr,
+        bias_attr=bias_attr,
+        param_attr_bn=param_attr_bn,
+        bn=False,
+        trans=True,
+        act=TanhActivation())
+
+
+def discriminator(sample):
+    """
+    The discriminator outputs the probability that a sample comes from the
+    generator or from real data.
+    The output has two dimensions: dimension 0 is the probability that the
+    sample comes from the generator, and dimension 1 is the probability
+    that it comes from real data.
+    """
+    param_attr = ParamAttr(
+        is_static=is_generator_training, initial_mean=0.0, initial_std=0.02)
+    bias_attr = ParamAttr(
+        is_static=is_generator_training, initial_mean=0.0, initial_std=0.0)
+
+    param_attr_bn = ParamAttr(
+        is_static=is_generator_training, initial_mean=1.0, initial_std=0.02)
+
+    h0 = conv_bn(
+        sample,
+        channels=c_dim,
+        imgSize=sample_dim,
+        num_filters=df_dim,
+        output_x=s2,
+        stride=2,
+        name="dis_h0",
+        param_attr=param_attr,
+        bias_attr=bias_attr,
+        param_attr_bn=param_attr_bn,
+        bn=False)
+
+    h1_bn = conv_bn(
+        h0,
+        channels=df_dim,
+        imgSize=s2,
+        num_filters=df_dim * 2,
+        output_x=s4,
+        stride=2,
+        name="dis_h1",
+        param_attr=param_attr,
+        bias_attr=bias_attr,
+        param_attr_bn=param_attr_bn,
+        bn=True)
+
+    h2_bn = conv_bn(
+        h1_bn,
+        channels=df_dim * 2,
+        imgSize=s4,
+        num_filters=df_dim * 4,
+        output_x=s8,
+        stride=2,
+        name="dis_h2",
+        param_attr=param_attr,
+        bias_attr=bias_attr,
+        param_attr_bn=param_attr_bn,
+        bn=True)
+
+    return fc_layer(
+        input=h2_bn,
+        name="dis_prob",
+        size=2,
+        bias_attr=bias_attr,
+        param_attr=param_attr,
+        act=SoftmaxActivation())
+
+
+if is_generator_training:
+    noise = data_layer(name="noise", size=noise_dim)
+    sample = generator(noise)
+
+if is_discriminator_training:
+    sample = data_layer(name="sample", size=sample_dim * sample_dim * c_dim)
+
+if is_generator_training or is_discriminator_training:
+    label = data_layer(name="label", size=1)
+    prob = discriminator(sample)
+    cost = cross_entropy(input=prob, label=label)
+    classification_error_evaluator(
+        input=prob, label=label, name=mode + '_error')
+    outputs(cost)
+
+if is_generator:
+    noise = data_layer(name="noise", size=noise_dim)
+    outputs(generator(noise))
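The `tmp = imgSize - (output_x - 1) * stride` arithmetic in `conv_bn` comes from the usual convolution size relation `output_x = (imgSize - filter_size + 2*padding) / stride + 1`. A small standalone sketch (plain Python, mirroring the logic above rather than calling PaddlePaddle) shows how the MNIST generator's deconvolution chain 7 -> 14 -> 28 resolves to concrete filter sizes:

```python
def infer_filter_and_padding(img_size, output_x, stride):
    # filter_size - 2 * padding must equal img_size - (output_x - 1) * stride
    tmp = img_size - (output_x - 1) * stride
    if tmp <= 1 or tmp > 5:
        raise ValueError("conv input-output dimension does not fit")
    if tmp <= 3:
        return tmp + 2, 1  # filter_size, padding
    return tmp, 0

# For a deconv layer, imgSize is the larger (output) side and output_x the
# smaller (input) side, so the same relation applies with the roles swapped.
for small, large in [(7, 14), (14, 28)]:
    print(infer_filter_and_padding(large, small, stride=2))
# prints (4, 1) twice: e.g. 28 - (14 - 1) * 2 = 2, hence filter 4, padding 1
```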
diff --git a/gan/gan_trainer.py b/gan/gan_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a26c230f7a21cc6dd4a3cdb52e32730b1ce73ca
--- /dev/null
+++ b/gan/gan_trainer.py
@@ -0,0 +1,349 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import random
+import numpy
+import cPickle
+import sys, os
+from PIL import Image
+
+from paddle.trainer.config_parser import parse_config
+from paddle.trainer.config_parser import logger
+import py_paddle.swig_paddle as api
+import matplotlib.pyplot as plt
+
+
+def plot2DScatter(data, outputfile):
+    '''
+    Plot the data as a 2D scatter plot and save it to outputfile.
+    The data needs to be two-dimensional.
+    '''
+    x = data[:, 0]
+    y = data[:, 1]
+    logger.info("The mean vector is %s" % numpy.mean(data, 0))
+    logger.info("The std vector is %s" % numpy.std(data, 0))
+
+    heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
+    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
+
+    plt.clf()
+    plt.scatter(x, y)
+    plt.savefig(outputfile, bbox_inches='tight')
+
+
+def CHECK_EQ(a, b):
+    assert a == b, "a=%s, b=%s" % (a, b)
+
+
+def copy_shared_parameters(src, dst):
+    '''
+    Copy the parameters from src to dst.
+    :param src: the source of the parameters
+    :type src: GradientMachine
+    :param dst: the destination of the parameters
+    :type dst: GradientMachine
+    '''
+    src_params = [src.getParameter(i) for i in xrange(src.getParameterSize())]
+    src_params = dict([(p.getName(), p) for p in src_params])
+
+    for i in xrange(dst.getParameterSize()):
+        dst_param = dst.getParameter(i)
+        src_param = src_params.get(dst_param.getName(), None)
+        if src_param is None:
+            continue
+        src_value = src_param.getBuf(api.PARAMETER_VALUE)
+        dst_value = dst_param.getBuf(api.PARAMETER_VALUE)
+        CHECK_EQ(len(src_value), len(dst_value))
+        dst_value.copyFrom(src_value)
+        dst_param.setValueUpdated()
+
+
+def print_parameters(src):
+    src_params = [src.getParameter(i) for i in xrange(src.getParameterSize())]
+
+    print "***************"
+    for p in src_params:
+        print "Name is %s" % p.getName()
+        print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray()
+
+
+def load_mnist_data(imageFile):
+    f = open(imageFile, "rb")
+    # skip the 16-byte IDX header
+    f.read(16)
+
+    # Define number of samples for train/test
+    if "train" in imageFile:
+        n = 60000
+    else:
+        n = 10000
+
+    data = numpy.fromfile(f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28))
+    # scale the pixel values to [-1, 1]
+    data = data / 255.0 * 2.0 - 1.0
+
+    f.close()
+    return data.astype('float32')
+
+
+def load_cifar_data(cifar_path):
+    batch_size = 10000
+    data = numpy.zeros((5 * batch_size, 32 * 32 * 3), dtype="float32")
+    for i in range(1, 6):
+        file = cifar_path + "/data_batch_" + str(i)
+        fo = open(file, 'rb')
+        dict = cPickle.load(fo)
+        fo.close()
+        data[(i - 1) * batch_size:(i * batch_size), :] = dict["data"]
+
+    data = data / 255.0 * 2.0 - 1.0
+    return data
+
+
+# synthesize 2-D uniform data
+def load_uniform_data():
+    data = numpy.random.rand(1000000, 2).astype('float32')
+    return data
+
+
+def merge(images, size):
+    if images.shape[1] == 28 * 28:
+        h, w, c = 28, 28, 1
+    else:
+        h, w, c = 32, 32, 3
+    img = numpy.zeros((h * size[0], w * size[1], c))
+    for idx in xrange(size[0] * size[1]):
+        i = idx % size[1]
+        j = idx // size[1]
+        img[j*h:j*h+h, i*w:i*w+w, :] = \
+            ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
+    return img.astype('uint8')
+
+
+def save_images(images, path):
+    merged_img = merge(images, [8, 8])
+    if merged_img.shape[2] == 1:
+        im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB')
+    else:
+        im = Image.fromarray(merged_img, mode="RGB")
+    im.save(path)
+
+
+def get_real_samples(batch_size, data_np):
+    return data_np[numpy.random.choice(
+        data_np.shape[0], batch_size, replace=False), :]
+
+
+def get_noise(batch_size, noise_dim):
+    return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
+
+
+def get_fake_samples(generator_machine, batch_size, noise):
+    gen_inputs = api.Arguments.createArguments(1)
+    gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
+    gen_outputs = api.Arguments.createArguments(0)
+    generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST)
+    fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat()
+    return fake_samples
+
+
+def get_training_loss(training_machine, inputs):
+    outputs = api.Arguments.createArguments(0)
+    training_machine.forward(inputs, outputs, api.PASS_TEST)
+    loss = outputs.getSlotValue(0).copyToNumpyMat()
+    return numpy.mean(loss)
+
+
+def prepare_discriminator_data_batch_pos(batch_size, data_np):
+    real_samples = get_real_samples(batch_size, data_np)
+    labels = numpy.ones(batch_size, dtype='int32')
+    inputs = api.Arguments.createArguments(2)
+    inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_samples))
+    inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
+    return inputs
+
+
+def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise):
+    fake_samples = get_fake_samples(generator_machine, batch_size, noise)
+    labels = numpy.zeros(batch_size, dtype='int32')
+    inputs = api.Arguments.createArguments(2)
+    inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(fake_samples))
+    inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
+    return inputs
+
+
+def prepare_generator_data_batch(batch_size, noise):
+    # the generator batch is labeled "real" (1) so that training pushes the
+    # generated samples towards being classified as real
+    label = numpy.ones(batch_size, dtype='int32')
+    inputs = api.Arguments.createArguments(2)
+    inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
+    inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label))
+    return inputs
+
+
+def find(iterable, cond):
+    for item in iterable:
+        if cond(item):
+            return item
+    return None
+
+
+def get_layer_size(model_conf, layer_name):
+    layer_conf = find(model_conf.layers, lambda x: x.name == layer_name)
+    assert layer_conf is not None, "Cannot find '%s' layer" % layer_name
+    return layer_conf.size
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-d", "--data_source", help="mnist or cifar or uniform")
+    parser.add_argument(
+        "--use_gpu", default="1", help="1 means use gpu for training")
+    parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter")
+    args = parser.parse_args()
+    data_source = args.data_source
+    use_gpu = args.use_gpu
+    assert data_source in ["mnist", "cifar", "uniform"]
+    assert use_gpu in ["0", "1"]
+
+    if not os.path.exists("./%s_samples/" % data_source):
+        os.makedirs("./%s_samples/" % data_source)
+
+    if not os.path.exists("./%s_params/" % data_source):
+        os.makedirs("./%s_params/" % data_source)
+
+    api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10',
+                   '--log_period=100', '--gpu_id=' + args.gpu_id,
+                   '--save_dir=' + "./%s_params/" % data_source)
+
+    if data_source == "uniform":
+        conf = "gan_conf.py"
+        num_iter = 10000
+    else:
+        conf = "gan_conf_image.py"
+        num_iter = 1000
+
+    gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
+    dis_conf = parse_config(conf,
+                            "mode=discriminator_training,data=" + data_source)
+    generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
+    batch_size = dis_conf.opt_config.batch_size
+    noise_dim = get_layer_size(gen_conf.model_config, "noise")
+
+    if data_source == "mnist":
+        data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte")
+    elif data_source == "cifar":
+        data_np = load_cifar_data("./data/cifar-10-batches-py/")
+    else:
+        data_np = load_uniform_data()
+
+    # this creates a gradient machine for the discriminator
+    dis_training_machine = api.GradientMachine.createFromConfigProto(
+        dis_conf.model_config)
+    # this creates a gradient machine for the generator
+    gen_training_machine = api.GradientMachine.createFromConfigProto(
+        gen_conf.model_config)
+
+    # generator_machine is used to generate data only, which is used for
+    # training the discriminator
+    logger.info(str(generator_conf.model_config))
+    generator_machine = api.GradientMachine.createFromConfigProto(
+        generator_conf.model_config)
+
+    dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
+
+    gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
+
+    dis_trainer.startTrain()
+    gen_trainer.startTrain()
+
+    # Sync parameters between networks (GradientMachine) at the beginning
+    copy_shared_parameters(gen_training_machine, dis_training_machine)
+    copy_shared_parameters(gen_training_machine, generator_machine)
+
+    # ensure that neither the discriminator nor the generator is trained
+    # consecutively for more than MAX_strike batches
+    curr_train = "dis"
+    curr_strike = 0
+    MAX_strike = 5
+
+    for train_pass in xrange(100):
+        dis_trainer.startTrainPass()
+        gen_trainer.startTrainPass()
+        for i in xrange(num_iter):
+            # Do a forward pass in the discriminator to get the dis_loss
+            noise = get_noise(batch_size, noise_dim)
+            data_batch_dis_pos = prepare_discriminator_data_batch_pos(
+                batch_size, data_np)
+            dis_loss_pos = get_training_loss(dis_training_machine,
+                                             data_batch_dis_pos)
+
+            data_batch_dis_neg = prepare_discriminator_data_batch_neg(
+                generator_machine, batch_size, noise)
+            dis_loss_neg = get_training_loss(dis_training_machine,
+                                             data_batch_dis_neg)
+
+            dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0
+
+            # Do a forward pass in the generator to get the gen_loss
+            data_batch_gen = prepare_generator_data_batch(batch_size, noise)
+            gen_loss = get_training_loss(gen_training_machine, data_batch_gen)
+
+            if i % 100 == 0:
+                print "d_pos_loss is %s     d_neg_loss is %s" % (dis_loss_pos,
+                                                                 dis_loss_neg)
+                print "d_loss is %s    g_loss is %s" % (dis_loss, gen_loss)
+
+            # Decide which network to train based on the training history
+            # and the relative size of the losses
+            if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \
+               ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss):
+                if curr_train == "dis":
+                    curr_strike += 1
+                else:
+                    curr_train = "dis"
+                    curr_strike = 1
+                dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg)
+                dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos)
+                copy_shared_parameters(dis_training_machine,
+                                       gen_training_machine)
+
+            else:
+                if curr_train == "gen":
+                    curr_strike += 1
+                else:
+                    curr_train = "gen"
+                    curr_strike = 1
+                gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
+                # TODO: add API for paddle to allow true parameter sharing
+                # between different GradientMachines, so that we do not need
+                # to copy shared parameters.
+                copy_shared_parameters(gen_training_machine,
+                                       dis_training_machine)
+                copy_shared_parameters(gen_training_machine, generator_machine)
+
+        dis_trainer.finishTrainPass()
+        gen_trainer.finishTrainPass()
+        # At the end of each pass, save the generated samples/images
+        fake_samples = get_fake_samples(generator_machine, batch_size, noise)
+        if data_source == "uniform":
+            plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" %
+                          (data_source, train_pass))
+        else:
+            save_images(fake_samples, "./%s_samples/train_pass%s.png" %
+                        (data_source, train_pass))
+    dis_trainer.finishTrain()
+    gen_trainer.finishTrain()
+
+
+if __name__ == '__main__':
+    main()
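For reference, a typical way to run the trainer above (assuming PaddlePaddle is installed with the `py_paddle` bindings and the data has been fetched with the scripts in `gan/data/`) is:

```bash
cd gan/
# synthetic 2-D uniform data; uses gan_conf.py
python gan_trainer.py -d uniform --use_gpu 0
# MNIST digits; uses gan_conf_image.py (run data/get_mnist_data.sh first)
python gan_trainer.py -d mnist --use_gpu 1 --gpu_id 0
```

Generated samples are written to `./<data_source>_samples/` and parameters to `./<data_source>_params/` at the end of each pass.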