# Design for GAN

GAN (Generative Adversarial Network [https://arxiv.org/abs/1406.2661]) is an important model for unsupervised learning and is widely used in many areas.

It exercises several important concepts in machine learning system design, including building and running subgraphs, dependency tracing, and running different optimizers in one executor.

In our GAN design, we wrap it in a user-friendly, easily customizable Python API for designing different models. We take the conditional DC-GAN (Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks [https://arxiv.org/abs/1511.06434]) as an example because of its good performance on image generation.

| Important building block   | People in charge  | Required |
|----------------------------|-------------------|----------|
| convolution 2d (done)      | Chengduo          | Y        |
| cudnn conv 2d (missing)    | Chengduo          | N        |
| deconv 2d (missing)        | Zhuoyuan, Zhihong | Y        |
| cudnn deconv 2d (missing)  | Zhuoyuan, Zhihong | N        |
| batch norm (missing)       | Zhuoyuan, Jiayi   | Y        |
| cudnn batch norm (missing) | Zhuoyuan, Jiayi   | N        |
| max-pooling (done)         | ?                 | Y        |
| cudnn-max-pool (missing)   | Chengduo          | Y        |
| fc (done)                  | ?                 | Y        |
| softmax loss (done)        | ?                 | Y        |
| reshape op (done)          | ?                 | Y        |
| Dependency Engine (done)   | Jiayi             | Y *      |
| Python API (done)          | Longfei, Jiayi    | Y *      |
| Executor (done)            | Tony              | Y *      |
| Multi optimizer (working)  | Longfei           | Y *      |
| Optimizer with any parameters | ?              | Y *      |
| Concat op (done)           | ?                 | N (Cond) |
| Repmat op (done)           | ?                 | N (Cond) |


<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
Figure borrowed from the original DC-GAN paper.
</p>

## The Conditional-GAN might be a class
This design adopts the popular open-source designs in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains the following data structure:

- `DCGAN(object)`: contains everything required to build a GAN model. It provides the following member functions as its API:

- `__init__(...)`: Initialize hyper-parameters (such as the conv dimensions) and declare the model parameters of both the discriminator and the generator.

- `generator(z, y=None)`: Generate a fake image from the input noise `z`. If the label `y` is provided, the conditional GAN model is used.
Returns a generated image.

- `discriminator(image)`:
Given an image, decide whether it comes from a real source or is a fake one.
Returns the logit of a binary (real/fake) label.

- `build_model(self)`:
Build the whole GAN model and define the training losses for both the generator and the discriminator.

## Discussion on Engine Functions required to build GAN
- Trace the tensor and variable dependencies in the engine executor. (Very critical; otherwise the GAN can't be trained correctly.)
- Different optimizers responsible for optimizing different losses.
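
Both requirements can be illustrated without any engine. The following pure-Python sketch is a hypothetical stand-in for the executor's dependency tracing, not its implementation: the graph is a plain dict mapping each node to its inputs, and the node names (`G`, `D_t`, `D_f`, `d_loss`, `g_loss`) mirror the ones used in `build_model` below, with fewer layers for brevity. It shows that `d_loss` reaches the generator weights through the fake images, and `g_loss` reaches the discriminator weights. This is why each optimizer must be restricted to its own parameter list.

```python
# Hypothetical toy graph of the GAN: node -> list of input nodes.
# Nodes without an entry are leaves: parameters or data placeholders.
GRAPH = {
    "G": ["z", "y", "G_W0", "G_b0", "G_W1", "G_b1"],
    "D_t": ["images", "D_W0", "D_b0", "D_W1", "D_b1"],
    "D_f": ["G", "D_W0", "D_b0", "D_W1", "D_b1"],
    "d_loss": ["D_t", "D_f"],
    "g_loss": ["D_f"],
}

THETA_D = {"D_W0", "D_b0", "D_W1", "D_b1"}
THETA_G = {"G_W0", "G_b0", "G_W1", "G_b1"}


def trace_leaves(node, graph):
    """Return every leaf (parameter or data input) that `node` depends on."""
    leaves, stack, seen = set(), [node], set()
    while stack:
        current = stack.pop()
        if current in seen:
            continue
        seen.add(current)
        inputs = graph.get(current)
        if inputs is None:  # no recorded inputs: this is a leaf
            leaves.add(current)
        else:
            stack.extend(inputs)
    return leaves


def trainable(loss, var_list, graph):
    """Parameters an optimizer for `loss` may update: reachable AND listed."""
    return trace_leaves(loss, graph) & set(var_list)


if __name__ == "__main__":
    # d_loss also reaches the generator weights through D_f -> G.
    print(sorted(trace_leaves("d_loss", GRAPH) & THETA_G))
    print(sorted(trainable("d_loss", THETA_D, GRAPH)))
    print(sorted(trainable("g_loss", THETA_G, GRAPH)))
```

Intersecting with `var_list` plays the role of the parameter list passed to each optimizer's `minimize` call in this design.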

In more detail, our design of DCGAN is as follows:

### Class member function: Initializer
- Set up hyper-parameters, including the conditional dimension, noise dimension, batch size and so forth.
- Declare and define all the model variables. All the discriminator parameters are included in the list `self.theta_D` and all the generator parameters are included in the list `self.theta_G`.
```python
import numpy as np
import paddle.v2 as pd  # pd.Variable etc. are the API proposed in this design


class DCGAN(object):
  def __init__(self, y_dim=None, z_dim=100):

    # hyper parameters
    self.y_dim = y_dim # conditional gan or not
    self.batch_size = 100
    self.z_dim = z_dim # input noise dimension
    self.im_size = 28 # MNIST images are 28 x 28

    # define parameters of discriminators
    self.D_W0 = pd.Variable(shape=[3,3, 1, 128], data=pd.gaussian_normal_randomizer())
    self.D_b0 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
    self.D_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
    self.D_b1 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
    self.D_W2 = pd.Variable(np.random.rand(128, 1))
    self.D_b2 = pd.Variable(np.zeros(1)) # one output unit: the real/fake logit
    self.theta_D = [self.D_W0, self.D_b0, self.D_W1, self.D_b1, self.D_W2, self.D_b2]

    # define parameters of generators
    # the generator input is the noise z, concatenated with the label y if conditional
    self.G_W0 = pd.Variable(shape=[self.z_dim + (self.y_dim or 0), 128], data=pd.gaussian_normal_randomizer())
    self.G_b0 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
    self.G_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
    self.G_b1 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
    self.G_W2 = pd.Variable(np.random.rand(128, 1))
    self.G_b2 = pd.Variable(np.zeros(1))
    self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2]
```
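
Most of the shapes above are placeholders until the deconv layer exists. As a hedged sketch of one consistent choice, the NumPy code below declares a fully connected generator and discriminator for 28 x 28 MNIST images. Weights are initialized from a zero-mean Gaussian with standard deviation 0.02, the setting reported in the DC-GAN paper, and biases with zeros. The helper name `init_params`, the hidden width of 128 and the use of plain dicts are assumptions, not part of the proposed `pd` API.

```python
import numpy as np


def init_params(z_dim=100, y_dim=10, hidden=128, im_size=28, seed=0):
    """Hypothetical helper: declare MLP GAN parameters with matching shapes."""
    rng = np.random.default_rng(seed)
    x_dim = im_size * im_size  # flattened image: 784 for MNIST

    def gaussian(*shape):
        # zero-mean Gaussian initialization with standard deviation 0.02
        return rng.normal(0.0, 0.02, size=shape)

    # generator: [z, y] -> hidden -> flattened image
    theta_G = {
        "G_W0": gaussian(z_dim + y_dim, hidden),
        "G_b0": np.zeros(hidden),
        "G_W1": gaussian(hidden, x_dim),
        "G_b1": np.zeros(x_dim),
    }
    # discriminator: flattened image -> hidden -> one real/fake logit
    theta_D = {
        "D_W0": gaussian(x_dim, hidden),
        "D_b0": np.zeros(hidden),
        "D_W1": gaussian(hidden, 1),
        "D_b1": np.zeros(1),
    }
    return theta_G, theta_D


theta_G, theta_D = init_params()
for name, value in {**theta_G, **theta_D}.items():
    print(name, value.shape)
```

Keeping the two groups in separate containers mirrors `self.theta_G` and `self.theta_D`, so each optimizer can be handed exactly one of them.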

### Class member function: Generator
- Given a noisy input z, returns a fake image.
- Concatenation, batch-norm and FC operations required;
- Deconv layer required, which is currently missing.
```python
class DCGAN(object):
  def generator(self, z, y=None):
    # input z: the random noise
    # input y: input data label (optional)
    # output G_im: generated fake images

    if self.y_dim: # conditional GAN: append the label to the noise
      z = pd.layer.concat(1, [z, y])

    G_h0 = pd.layer.fc(z, self.G_W0, self.G_b0)
    G_h0_bn = pd.layer.batch_norm(G_h0)
    G_h0_relu = pd.layer.relu(G_h0_bn)

    G_h1 = pd.layer.deconv(G_h0_relu, self.G_W1, self.G_b1)
    G_h1_bn = pd.layer.batch_norm(G_h1)
    G_h1_relu = pd.layer.relu(G_h1_bn)

    G_h2 = pd.layer.deconv(G_h1_relu, self.G_W2, self.G_b2)
    G_im = pd.layer.tanh(G_h2) # pixel values in [-1, 1]
    return G_im
```
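
Until the deconvolution layer is available, the control flow of the generator can be prototyped with fully connected layers only. The NumPy sketch below is such a stand-in, not the proposed `pd.layer` implementation. Its steps are:

- concatenate the one-hot label only when one is given;
- apply FC, then batch normalization (current batch statistics, with no learned scale or shift, as a simplifying assumption), then ReLU;
- finish with `tanh`, so pixel values lie in [-1, 1].

All function names and sizes are assumptions.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    # normalize each feature with the current batch statistics (training mode)
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def one_hot(labels, num_classes):
    # row i is the one-hot encoding of labels[i]
    return np.eye(num_classes)[labels]


def generator(z, params, y=None):
    """FC stand-in for the DCGAN generator; returns flattened fake images."""
    if y is not None:  # conditional GAN: append the label to the noise
        z = np.concatenate([z, y], axis=1)
    h0 = z @ params["G_W0"] + params["G_b0"]
    h0 = np.maximum(batch_norm(h0), 0.0)  # batch norm followed by ReLU
    h1 = h0 @ params["G_W1"] + params["G_b1"]
    return np.tanh(h1)  # pixel values in [-1, 1]


rng = np.random.default_rng(0)
batch, z_dim, y_dim, hidden, x_dim = 8, 100, 10, 128, 784
params = {
    "G_W0": rng.normal(0.0, 0.02, (z_dim + y_dim, hidden)),
    "G_b0": np.zeros(hidden),
    "G_W1": rng.normal(0.0, 0.02, (hidden, x_dim)),
    "G_b1": np.zeros(x_dim),
}
z = rng.uniform(-1.0, 1.0, (batch, z_dim))
y = one_hot(rng.integers(0, y_dim, batch), y_dim)
fake = generator(z, params, y)
print(fake.shape)  # (8, 784)
```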

### Class member function: Discriminator
- Given an image (real or generated), returns a logit indicating whether it is real or fake.
- Concatenation, convolution, batch-norm, FC and Leaky-ReLU operations required;
```python
class DCGAN(object):
  def discriminator(self, image):
    # input image: either generated images or real ones
    # output D_h2: binary logit of the label

    D_h0 = pd.layer.conv2d(image, w=self.D_W0, b=self.D_b0)
    D_h0_bn = pd.layer.batch_norm(D_h0)
    D_h0_relu = pd.layer.lrelu(D_h0_bn)

    D_h1 = pd.layer.conv2d(D_h0_relu, w=self.D_W1, b=self.D_b1)
    D_h1_bn = pd.layer.batch_norm(D_h1)
    D_h1_relu = pd.layer.lrelu(D_h1_bn)

    D_h2 = pd.layer.fc(D_h1_relu, w=self.D_W2, b=self.D_b2)
    return D_h2
```
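
In the same spirit, here is a NumPy stand-in for the discriminator. It is not the `pd.layer` API, and it replaces the conv layers with FC layers as an assumption. It shows the intended data flow: flatten the image, then apply FC, batch normalization and leaky ReLU, then a final FC layer producing one unnormalized logit per image. The leak slope of 0.2 follows the DC-GAN paper. The sigmoid is left to the loss.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    # normalize each feature with the current batch statistics (training mode)
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def lrelu(x, alpha=0.2):
    # leaky ReLU: keep positive values, scale negative ones by alpha
    return np.where(x > 0, x, alpha * x)


def discriminator(image, params):
    """FC stand-in for the DCGAN discriminator; returns logits of shape [batch, 1]."""
    x = image.reshape(image.shape[0], -1)  # flatten [batch, H, W] -> [batch, H*W]
    h0 = lrelu(batch_norm(x @ params["D_W0"] + params["D_b0"]))
    return h0 @ params["D_W1"] + params["D_b1"]  # raw logit, no sigmoid


rng = np.random.default_rng(0)
batch, im_size, hidden = 8, 28, 128
params = {
    "D_W0": rng.normal(0.0, 0.02, (im_size * im_size, hidden)),
    "D_b0": np.zeros(hidden),
    "D_W1": rng.normal(0.0, 0.02, (hidden, 1)),
    "D_b1": np.zeros(1),
}
images = rng.uniform(-1.0, 1.0, (batch, im_size, im_size))
logits = discriminator(images, params)
print(logits.shape)  # (8, 1)
```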

### Class member function: Build the model
- Define data readers as placeholders to hold the data;
- Build the generator and the discriminator;
- Define two training losses, one for the discriminator and one for the generator.
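
Written out as formulas, the two cross-entropy losses are given below. Here $D(x)$ is the sigmoid of the discriminator logit for image $x$, $G(z_i, y_i)$ is a generated image ($y_i$ is dropped for the unconditional GAN) and $m$ is the batch size:

$$
\begin{aligned}
\mathcal{L}_D &= -\frac{1}{m}\sum_{i=1}^{m}\log D(x_i) \;-\; \frac{1}{m}\sum_{i=1}^{m}\log\bigl(1 - D(G(z_i, y_i))\bigr),\\
\mathcal{L}_G &= -\frac{1}{m}\sum_{i=1}^{m}\log D(G(z_i, y_i)).
\end{aligned}
$$

$\mathcal{L}_D$ is minimized only over `theta_D`, and $\mathcal{L}_G$ only over `theta_G`. $\mathcal{L}_G$ is the non-saturating generator loss suggested in the original GAN paper. It replaces direct minimization of $\log(1 - D(G(z)))$, which gives weak gradients early in training.
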
```python
class DCGAN(object):
  def build_model(self):
    # input data
    if self.y_dim:
      self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
    self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
    self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
    self.z = pd.data(pd.float32, [self.batch_size, self.z_dim])

    # step 1: generate images by generator, classify real/fake images with discriminator
    # (self.sampler, not defined above, is assumed to run the generator in inference mode with shared weights)
    if self.y_dim: # if conditional GAN, includes label
      self.G = self.generator(self.z, self.y)
      self.D_t = self.discriminator(self.images)
      # generated fake images
      self.sampled = self.sampler(self.z, self.y)
      self.D_f = self.discriminator(self.G)
    else: # original version of GAN
      self.G = self.generator(self.z)
      self.D_t = self.discriminator(self.images)
      # generate fake images
      self.sampled = self.sampler(self.z)
      self.D_f = self.discriminator(self.G)

    # step 2: define the two losses
    # discriminator: real images -> label 1, generated images -> label 0
    self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)))
    self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)))
    self.d_loss = self.d_loss_real + self.d_loss_fake

    # generator: try to make the discriminator label generated images as real
    self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_size)))
```
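
Since the discriminator returns a raw logit (as the `D_h2` comment states), the cross-entropy is best computed from logits. A standard numerically stable form for logit `l` and target `t` is `max(l, 0) - l*t + log(1 + exp(-|l|))`. The NumPy sketch below uses this form to reproduce `d_loss` and `g_loss`. The function names are hypothetical.

```python
import numpy as np


def sigmoid_cross_entropy(logits, targets):
    """Mean binary cross-entropy between sigmoid(logits) and 0/1 targets."""
    logits = np.asarray(logits, dtype=float)
    # stable form of -[t*log(s) + (1-t)*log(1-s)] with s = sigmoid(logits)
    loss = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return loss.mean()


def gan_losses(d_logits_real, d_logits_fake):
    """Discriminator and generator losses as defined in build_model."""
    d_loss = (sigmoid_cross_entropy(d_logits_real, np.ones_like(d_logits_real))
              + sigmoid_cross_entropy(d_logits_fake, np.zeros_like(d_logits_fake)))
    g_loss = sigmoid_cross_entropy(d_logits_fake, np.ones_like(d_logits_fake))
    return d_loss, g_loss


# An undecided discriminator (all logits 0) gives d_loss = 2*log(2), g_loss = log(2).
d_loss, g_loss = gan_losses(np.zeros((4, 1)), np.zeros((4, 1)))
print(round(d_loss, 4), round(g_loss, 4))  # 1.3863 0.6931
```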

## Main function for the demo
Generally, a GAN user only needs to do the following:
- Define an object of the DCGAN class;
- Build the DCGAN model;
- Specify two optimizers for the two losses, each with respect to its own parameters.
```python
# pd for short, should be more concise.
import paddle.v2 as pd
import numpy as np
import logging

if __name__ == "__main__":
    # conditional dcgan on MNIST (10 classes)
    dcgan = DCGAN(y_dim=10)
    dcgan.build_model()

    # load mnist data (load_mnist is a data-loading helper not shown here)
    data_X, data_y = load_mnist()
    N = data_X.shape[0]
    batch_size, z_dim = dcgan.batch_size, dcgan.z_dim

    # Two subgraphs required!!!
    d_optim = pd.train.Adam(lr=.001, beta=.1).minimize(dcgan.d_loss, dcgan.theta_D)
    g_optim = pd.train.Adam(lr=.001, beta=.1).minimize(dcgan.g_loss, dcgan.theta_G)

    # executor
    sess = pd.executor()

    # training
    for epoch in range(10000):
      for batch_id in range(N // batch_size):
        idx = ...
        # sample a batch
        batch_im, batch_label = data_X[idx:idx+batch_size], data_y[idx:idx+batch_size]
        # sample z
        batch_z = np.random.uniform(-1., 1., [batch_size, z_dim])

        if batch_id % 2 == 0:
          # update the discriminator parameters (theta_D) only
          sess.run(d_optim,
                   feed_dict = {dcgan.images: batch_im,
                                dcgan.y: batch_label,
                                dcgan.z: batch_z})
        else:
          # update the generator parameters (theta_G) only;
          # the conditional generator also needs the label y
          sess.run(g_optim,
                   feed_dict = {dcgan.y: batch_label,
                                dcgan.z: batch_z})
```
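
The alternating schedule of the loop above can be tried end to end on a toy problem without any framework. The NumPy sketch below illustrates it under strong simplifying assumptions and is not the proposed executor. It assumes:

- 1-D real data;
- a linear generator `x = a*z + b`;
- a logistic-regression discriminator `sigmoid(w*x + c)`;
- hand-derived gradients and plain SGD instead of Adam.

Even batches update only `theta_D = (w, c)` using `d_loss`, and odd batches update only `theta_G = (a, b)` using `g_loss`. This mirrors the two `minimize(...)` calls. The function `train_toy_gan` and all constants are hypothetical.

```python
import numpy as np


def sigmoid(x):
    # overflow-free form of 1 / (1 + exp(-x))
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def train_toy_gan(steps=2000, batch_size=64, lr=0.05, seed=0):
    """Toy 1-D GAN trained with the alternating D/G schedule (hypothetical)."""
    rng = np.random.default_rng(seed)
    theta_D = {"w": 0.1, "c": 0.0}  # discriminator: D(x) = sigmoid(w * x + c)
    theta_G = {"a": 1.0, "b": 0.0}  # generator: G(z) = a * z + b
    history = []
    for batch_id in range(steps):
        real = rng.normal(3.0, 0.5, batch_size)  # toy "real" data
        z = rng.uniform(-1.0, 1.0, batch_size)
        fake = theta_G["a"] * z + theta_G["b"]
        s_real = sigmoid(theta_D["w"] * real + theta_D["c"])
        s_fake = sigmoid(theta_D["w"] * fake + theta_D["c"])
        if batch_id % 2 == 0:
            # d_loss = -mean(log s_real) - mean(log(1 - s_fake)); update theta_D only.
            # Per-sample derivatives of d_loss w.r.t. the logits (before averaging):
            g_real = s_real - 1.0
            g_fake = s_fake
            theta_D["w"] -= lr * (np.mean(g_real * real) + np.mean(g_fake * fake))
            theta_D["c"] -= lr * (np.mean(g_real) + np.mean(g_fake))
        else:
            # g_loss = -mean(log s_fake); update theta_G only.
            # Chain rule: d g_loss / d fake = (s_fake - 1) * w per sample.
            g_x = (s_fake - 1.0) * theta_D["w"]
            theta_G["a"] -= lr * np.mean(g_x * z)
            theta_G["b"] -= lr * np.mean(g_x)
        history.append((dict(theta_D), dict(theta_G)))
    return theta_D, theta_G, history


theta_D, theta_G, history = train_toy_gan()
print("generator offset b after training:", float(theta_G["b"]))
```

Even in this toy, GAN training need not converge smoothly. Its purpose is only to show that each step touches exactly one parameter group.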