# Generative Adversarial Networks (GAN) 

This demo implements GAN training as described in the original [GAN paper](https://arxiv.org/abs/1406.2661) and in the deep convolutional GAN [DCGAN paper](https://arxiv.org/abs/1511.06434).

The high-level structure of a GAN is shown in Figure 1 below. It has two major parts, a generator and a discriminator, both of which are neural networks. The generator takes noise drawn from a known distribution and transforms it into an image. The discriminator takes an image and determines whether it was produced by the generator or is a real image. The generator and the discriminator therefore play a competitive game: the generator tries to produce images that look real enough to fool the discriminator, while the discriminator tries to distinguish real images from fake ones.

<p align="center">
    <img src="./gan.png" width="500" height="300"> 
</p>
<p align="center">
    Figure 1. GAN model structure ([figure credit](https://ishmaelbelghazi.github.io/ALI/))
</p>

The generator and the discriminator take turns being trained with SGD. The generator's objective is to have its generated images classified as real by the discriminator, and the discriminator's objective is to classify real and fake images correctly. If training converges to the equilibrium state, the generator transforms the given noise distribution into the distribution of real images, and the discriminator can no longer distinguish real from fake images, doing no better than chance.
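
Both objectives can be written as cross-entropy losses on the discriminator's output probability. The original GAN paper frames the game as a minimax over log D(x) + log(1 - D(G(z))) and notes that the generator can instead maximize log D(G(z)), which is what "classified as real" amounts to. The sketch below is a framework-independent NumPy illustration, not the demo's Paddle code; `d_real` and `d_fake` are hypothetical arrays holding the discriminator's P(real) on real and generated samples.

```python
import numpy as np


def binary_cross_entropy(prob, label, eps=1e-7):
    """Mean cross-entropy between predicted P(real) and 0/1 labels."""
    prob = np.clip(prob, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-(label * np.log(prob) + (1 - label) * np.log(1 - prob))))


def discriminator_loss(d_real, d_fake):
    """Discriminator objective: label real samples 1 and generated samples 0."""
    probs = np.concatenate([d_real, d_fake])
    labels = np.concatenate([np.ones_like(d_real), np.zeros_like(d_fake)])
    return binary_cross_entropy(probs, labels)


def generator_loss(d_fake):
    """Generator objective: have its samples labelled real (1)."""
    return binary_cross_entropy(d_fake, np.ones_like(d_fake))


# A discriminator that outputs 0.5 for every input cannot tell real from fake;
# both losses then equal log(2), roughly 0.693.
chance = np.full(8, 0.5)
print(discriminator_loss(chance, chance), generator_loss(chance))
```

A discriminator that separates the classes well (`d_real` near 1, `d_fake` near 0) drives its own loss toward 0 while pushing the generator's loss up, which is the competition described above.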

## Implementation of GAN Model Structure
Because a GAN involves multiple neural networks, it requires the Paddle Python API, so the code walk-through below also serves as a partial introduction to using the Paddle Python API.

There are three networks defined in gan_conf.py: **generator_training**, **discriminator_training** and **generator**. They relate to the model structure above as follows: **discriminator_training** is the discriminator, **generator** is the generator, and **generator_training** combines the generator and the discriminator, because training the generator requires the discriminator to provide the loss. This relationship is expressed in the following code:
```python
if is_generator_training:
    noise = data_layer(name="noise", size=noise_dim)
    sample = generator(noise)

if is_discriminator_training:
    sample = data_layer(name="sample", size=sample_dim)

if is_generator_training or is_discriminator_training:
    label = data_layer(name="label", size=1)
    prob = discriminator(sample)
    cost = cross_entropy(input=prob, label=label)
    classification_error_evaluator(
        input=prob, label=label, name=mode + '_error')
    outputs(cost)

if is_generator:
    noise = data_layer(name="noise", size=noise_dim)
    outputs(generator(noise))
```
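
To see why **generator_training** needs the discriminator, note that its cost is the cross-entropy of `discriminator(generator(noise))` against the "real" label, so the gradient for the generator's parameters has to pass back through the discriminator. The NumPy sketch below makes this explicit with a hypothetical linear generator and logistic-regression discriminator (stand-ins, not the layers in gan_conf.py). Updating only the generator's parameters and keeping the discriminator fixed during this step is an assumption of the sketch.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def generator(noise, W, b):
    """Hypothetical linear generator: noise (n, k) -> samples (n, 2)."""
    return noise @ W + b


def discriminator(sample, v, c):
    """Hypothetical logistic discriminator: samples (n, 2) -> P(real), shape (n,)."""
    return sigmoid(sample @ v + c)


def generator_training_loss(noise, W, b, v, c):
    """Cost of the combined network: cross-entropy against label 1 (real)."""
    prob = discriminator(generator(noise, W, b), v, c)
    return float(np.mean(-np.log(prob)))


def generator_gradients(noise, W, b, v, c):
    """Gradients of the combined cost w.r.t. the generator's W and b only."""
    n = noise.shape[0]
    prob = discriminator(generator(noise, W, b), v, c)
    dlogit = -(1.0 - prob) / n      # d(-log sigmoid(t))/dt, averaged over the batch
    dsample = np.outer(dlogit, v)   # back through the discriminator, shape (n, 2)
    grad_W = noise.T @ dsample      # back through the generator, shape (k, 2)
    grad_b = dsample.sum(axis=0)    # shape (2,)
    return grad_W, grad_b


# One SGD step on the generator, with the discriminator held fixed.
rng = np.random.default_rng(0)
noise = rng.normal(size=(16, 10))
W, b = 0.1 * rng.normal(size=(10, 2)), np.zeros(2)
v, c = np.array([1.0, -1.0]), 0.0
before = generator_training_loss(noise, W, b, v, c)
grad_W, grad_b = generator_gradients(noise, W, b, v, c)
W, b = W - 0.1 * grad_W, b - 0.1 * grad_b
after = generator_training_loss(noise, W, b, v, c)
print(before, after)  # the combined cost decreases after the step
```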

To train the networks defined in gan_conf.py, one first initializes the Paddle environment, parses the config once per mode, creates a GradientMachine from each parsed config, and creates a trainer from each training GradientMachine, as in the code below:
```python
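import py_paddle.swig_paddle as api
from paddle.trainer.config_parser import parse_config

# conf (the config file), data_source, use_gpu and args are assumed to be
# defined earlier in gan_trainer.py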
# init paddle environment
api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10',
               '--log_period=100', '--gpu_id=' + args.gpu_id,
               '--save_dir=' + "./%s_params/" % data_source)

# Parse config
gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + data_source)

# Create GradientMachine
dis_training_machine = api.GradientMachine.createFromConfigProto(
    dis_conf.model_config)
gen_training_machine = api.GradientMachine.createFromConfigProto(
    gen_conf.model_config)
generator_machine = api.GradientMachine.createFromConfigProto(
    generator_conf.model_config)

# Create trainer
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
```

To balance the strength of the generator and the discriminator, we train whichever one is currently performing worse, as judged by comparing their loss values. A loss value can be computed with a forward pass through the corresponding GradientMachine:
```python
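import numpy
import py_paddle.swig_paddle as api

def get_training_loss(training_machine, inputs):
    # Empty container that receives the network's outputs (the cost layer)
    outputs = api.Arguments.createArguments(0)
    # Forward pass only, in test mode: computes the cost without training
    training_machine.forward(inputs, outputs, api.PASS_TEST)
    # Per-sample cost values as a NumPy matrix, averaged into one number
    loss = outputs.getSlotValue(0).copyToNumpyMat()
    return numpy.mean(loss)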
```
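
The scheduling rule can be sketched in plain Python, assuming `dis_loss` and `gen_loss` are values returned by `get_training_loss` for the two training machines. The `max_streak` cap, which forces a switch after the same network has been trained several times in a row, is a hypothetical safeguard added for illustration; the text above only specifies comparing the two losses.

```python
def choose_network_to_train(dis_loss, gen_loss, last_trained, streak, max_streak=5):
    """Return "dis" or "gen": the network that should train on the next batch.

    The network with the higher loss (the one performing worse) is chosen,
    unless ``last_trained`` has already been picked ``max_streak`` times in a
    row, in which case the other network gets a turn (hypothetical cap).
    """
    if last_trained is not None and streak >= max_streak:
        return "gen" if last_trained == "dis" else "dis"
    return "dis" if dis_loss > gen_loss else "gen"


def schedule(loss_pairs, max_streak=5):
    """Replay the rule over a sequence of (dis_loss, gen_loss) measurements."""
    order, last, streak = [], None, 0
    for dis_loss, gen_loss in loss_pairs:
        choice = choose_network_to_train(dis_loss, gen_loss, last, streak, max_streak)
        streak = streak + 1 if choice == last else 1
        last = choice
        order.append(choice)
    return order


# The discriminator keeps losing, but the cap still hands the generator a turn.
print(schedule([(0.9, 0.5)] * 4, max_streak=2))  # ['dis', 'dis', 'gen', 'dis']
```

In a real training loop the losses would be re-measured before every step, since each update changes them; the fixed pairs here only exercise the rule.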

Because each GradientMachine holds its own copy of the shared parameters, the updated parameters must be copied to the other networks after one of them is trained. The code below shows this for the generator_training network:
```python
# Train the generator_training network on one batch
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)

# Copy the updated parameters from generator_training to
# discriminator_training and generator
copy_shared_parameters(gen_training_machine, dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine)
```
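
Conceptually, `copy_shared_parameters` overwrites the destination machine's copy of every parameter that both machines contain. The NumPy sketch below models each GradientMachine's parameters as a dict from parameter name to array; matching by name, the dict representation, and the parameter names `gen_fc.w` and `dis_fc.w` are assumptions for illustration, not the demo's implementation.

```python
import numpy as np


def sync_shared_parameters(src, dst):
    """Copy each parameter that exists in both ``src`` and ``dst`` into ``dst``.

    ``src`` and ``dst`` map parameter names to NumPy arrays (hypothetical
    stand-ins for the parameters held by two GradientMachines). Parameters
    present in only one of them are left untouched.
    """
    for name, value in src.items():
        if name not in dst:
            continue
        if dst[name].shape != value.shape:
            raise ValueError("shape mismatch for parameter %r" % name)
        dst[name][...] = value  # in-place copy: dst keeps its own buffer


# After a generator_training step, push its parameters to the other networks.
gen_training = {"gen_fc.w": np.ones((10, 2)), "dis_fc.w": np.zeros((2, 1))}
dis_training = {"dis_fc.w": np.full((2, 1), 5.0)}
generator = {"gen_fc.w": np.zeros((10, 2))}
sync_shared_parameters(gen_training, dis_training)
sync_shared_parameters(gen_training, generator)
```

Copying in place rather than rebinding keeps each network's own buffer, so the networks stay independent until the next sync.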


## A Toy Example 
With the infrastructure explained above, we can now walk through a toy example: generating samples from a two-dimensional uniform distribution using 10-dimensional Gaussian noise.

The Gaussian noise is generated using the code below:
```python
import numpy

def get_noise(batch_size, noise_dim):
    # Standard normal noise, one row per sample, cast to float32
    return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
```

The real samples (2-D uniform) are generated using the code below:
```python
import numpy

# Synthesize 2-D uniform data (gan_trainer.py:114)
def load_uniform_data():
    # 1,000,000 points drawn uniformly from the unit square [0, 1) x [0, 1)
    data = numpy.random.rand(1000000, 2).astype('float32')
    return data
```
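
With the noise and the real samples in hand, each training step needs its inputs arranged to match the data layers in gan_conf.py: `noise` and `label` for **generator_training**, `sample` and `label` for **discriminator_training**. The sketch below shows one way to assemble them in NumPy. The dict format, the integer label type, and mixing real and fake samples in one shuffled batch are assumptions; the demo passes its data to Paddle in its own format (e.g. `data_batch_gen` above), which this sketch does not reproduce.

```python
import numpy as np


def generator_batch(noise):
    """Inputs for generator_training: noise, with every label set to 1 (real)."""
    labels = np.ones((noise.shape[0], 1), dtype=np.int32)
    return {"noise": noise, "label": labels}


def discriminator_batch(real, fake, rng=np.random):
    """Inputs for discriminator_training: real samples labelled 1, fake labelled 0."""
    samples = np.concatenate([real, fake]).astype(np.float32)
    labels = np.concatenate([np.ones(len(real), dtype=np.int32),
                             np.zeros(len(fake), dtype=np.int32)]).reshape(-1, 1)
    order = rng.permutation(len(samples))  # shuffle samples and labels together
    return {"sample": samples[order], "label": labels[order]}


real = np.random.rand(64, 2).astype('float32')             # like load_uniform_data()
noise = np.random.normal(size=(64, 10)).astype('float32')   # like get_noise(64, 10)
fake = np.random.normal(size=(64, 2)).astype('float32')     # stand-in for generator output
gen_feed = generator_batch(noise)
dis_feed = discriminator_batch(real, fake)
```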

The generator and discriminator networks are built from fully-connected and batch_norm layers and are defined in gan_conf.py.
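
A minimal NumPy sketch of such a generator follows: a fully-connected layer, batch normalization computed from the current batch's statistics (as during training), a ReLU, and a final fully-connected layer producing a 2-D point. The hidden size, the ReLU activation, and the layer order are assumptions; gan_conf.py defines the actual networks.

```python
import numpy as np


def batch_norm(x, gamma, beta, eps=1e-5):
    """Training-mode batch normalization: normalize each feature over the batch."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def mlp_generator(noise, params):
    """Hypothetical generator: fc -> batch_norm -> relu -> fc (2-D output)."""
    h = noise @ params["w1"] + params["b1"]
    h = batch_norm(h, params["gamma"], params["beta"])
    h = np.maximum(h, 0.0)  # ReLU
    return h @ params["w2"] + params["b2"]


def init_params(noise_dim=10, hidden=64, out_dim=2, seed=0):
    """Small random weights; batch-norm scale 1 and shift 0."""
    rng = np.random.default_rng(seed)
    return {
        "w1": rng.normal(scale=0.1, size=(noise_dim, hidden)),
        "b1": np.zeros(hidden),
        "gamma": np.ones(hidden),
        "beta": np.zeros(hidden),
        "w2": rng.normal(scale=0.1, size=(hidden, out_dim)),
        "b2": np.zeros(out_dim),
    }


noise = np.random.default_rng(1).normal(size=(128, 10))
samples = mlp_generator(noise, init_params())
print(samples.shape)  # (128, 2)
```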

To train the GAN model, use the command below. The flag -d selects the training data (cifar, mnist or uniform), and the flag --useGpu selects whether to train on the GPU (1) or the CPU (0).
```bash
$ python gan_trainer.py -d uniform --useGpu 1
```
The generated samples can be found in ./uniform_samples/; one example is shown below as Figure 2. The generated samples roughly recover the 2-D uniform distribution.

<p align="center">
    <img src="./uniform_sample.png" width="300" height="300"> 
</p>
<p align="center">
    Figure 2. Uniform Sample
</p>

## MNIST Example
### Data preparation
To download the MNIST data, one can use the following commands:
```bash
$ cd data/
$ ./get_mnist_data.sh
```

### Model description
Following the [DCGAN paper](https://arxiv.org/abs/1511.06434), we use convolution layers in the discriminator and convolution-transpose layers in the generator to better handle images. The details of the network structures are defined in gan_conf_image.py.
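
In a DCGAN-style design, the two layer types matter because of their effect on spatial size: a strided convolution shrinks a feature map and a strided convolution-transpose enlarges it, so the discriminator can reduce a 28x28 MNIST image step by step while the generator grows a small feature map up to 28x28. The standard output-size formulas are sketched below. Kernel size 4, stride 2 and padding 1 are illustrative choices, not necessarily those in gan_conf_image.py, and Paddle's layers may pad or round differently.

```python
def conv_output_size(in_size, kernel, stride, padding):
    """Spatial size after a convolution (no dilation)."""
    return (in_size + 2 * padding - kernel) // stride + 1


def conv_transpose_output_size(in_size, kernel, stride, padding):
    """Spatial size after a transposed convolution (no extra output padding)."""
    return (in_size - 1) * stride - 2 * padding + kernel


# Discriminator direction: 28 -> 14 -> 7
print(conv_output_size(28, 4, 2, 1), conv_output_size(14, 4, 2, 1))  # 14 7
# Generator direction: 7 -> 14 -> 28
print(conv_transpose_output_size(7, 4, 2, 1), conv_transpose_output_size(14, 4, 2, 1))  # 14 28
```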

### Training the model
To train the GAN model on the MNIST data, use the following command:
```bash
$ python gan_trainer.py -d mnist --useGpu 1
```
The generated sample images can be found in ./mnist_samples/; one example is shown below as Figure 3.
<p align="center">
    <img src="./mnist_sample.png" width="300" height="300"> 
</p>
<p align="center">
    Figure 3. MNIST Sample
</p>