gan_api.md 4.9 KB
Newer Older
Z
zchen0211 已提交
1 2 3
'''
GAN implementation, just a demo.
'''
Z
gan api  
zchen0211 已提交
4
```python
Z
zchen0211 已提交
5 6 7 8
import logging

import numpy as np
# pd for short, should be more concise.
import paddle.v2 as pd
Z
gan api  
zchen0211 已提交
9
```
Z
zchen0211 已提交
10 11 12 13 14 15

<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
The original GAN paper.
</p>

Z
zchen0211 已提交
16 17
# Conditional-GAN should be a class. 
### Class member function: the initializer.
Z
gan api  
zchen0211 已提交
18
```python
Z
zchen0211 已提交
19 20 21 22 23 24 25 26 27
class DCGAN(object):
    """Deep convolutional GAN, optionally conditioned on class labels.

    Args:
        y_dim: dimension of the conditioning label vector. ``None`` (the
            default) builds the original, unconditional GAN.
        z_dim: dimension of the input noise vector. The default of 100 is a
            choice made here; adjust as needed.
        im_size: height/width of the square input images (784 = 28 * 28 is
            the flattened size used by the layer shapes below).
    """

    def __init__(self, y_dim=None, z_dim=100, im_size=28):
        # hyper parameters
        self.y_dim = y_dim  # conditional gan or not
        self.batch_size = 100
        self.z_dim = z_dim  # input noise dimension
        self.im_size = im_size  # read by build_model() for the image placeholders

        # define parameters of discriminators
        self.D_W0 = pd.Variable(shape=[3, 3, 1, 128], data=pd.gaussian_normal_randomizer())
        self.D_b0 = pd.Variable(np.zeros(128))  # variables also support initialization from numpy data
        # NOTE(review): D_W1 is passed to pd.conv2d in discriminator() but has
        # an fc-style 2-D shape; confirm the intended layer type/shape.
        self.D_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
        self.D_b1 = pd.Variable(np.zeros(128))
        self.D_W2 = pd.Variable(np.random.rand(128, 1))
        # One bias per output unit of D_W2 (which has a single output column).
        self.D_b2 = pd.Variable(np.zeros(1))
        self.theta_D = [self.D_W0, self.D_b0, self.D_W1, self.D_b1, self.D_W2, self.D_b2]

        # define parameters of generators
        # The first generator layer consumes the noise z, concatenated with the
        # labels y for the conditional GAN, so its input width is z_dim (+ y_dim).
        g_in_dim = self.z_dim + (self.y_dim or 0)
        self.G_W0 = pd.Variable(shape=[g_in_dim, 128], data=pd.gaussian_normal_randomizer())
        self.G_b0 = pd.Variable(np.zeros(128))
        self.G_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
        self.G_b1 = pd.Variable(np.zeros(128))
        self.G_W2 = pd.Variable(np.random.rand(128, 1))
        # One bias per output channel of G_W2 (a single output channel).
        self.G_b2 = pd.Variable(np.zeros(1))
        self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2]
```
Z
zchen0211 已提交
45 46

### Class member function: Generator Net
Z
gan api  
zchen0211 已提交
47
```python
Z
zchen0211 已提交
48 49 50 51 52 53 54 55 56 57
def generator(self, z, y=None):
    """Generator net: map noise ``z`` (plus labels ``y`` if conditional) to an image.

    Args:
        z: noise input.
        y: conditioning labels; only used when ``self.y_dim`` is set.

    Returns:
        The generated image, passed through tanh as the final activation.
    """
    # Conditional GAN: append the labels to the noise vector. (The original
    # test was inverted and concatenated y only when there was no y.)
    if self.y_dim:
        z = pd.concat(1, [z, y])

    G_h0 = pd.fc(z, self.G_W0, self.G_b0)
    G_h0_bn = pd.batch_norm(G_h0)
    G_h0_relu = pd.relu(G_h0_bn)

    G_h1 = pd.deconv(G_h0_relu, self.G_W1, self.G_b1)
    G_h1_bn = pd.batch_norm(G_h1)
    G_h1_relu = pd.relu(G_h1_bn)

    # Output layer: no batch norm; tanh squashes pixel values into [-1, 1].
    G_h2 = pd.deconv(G_h1_relu, self.G_W2, self.G_b2)
    G_im = pd.tanh(G_h2)
    return G_im
Z
gan api  
zchen0211 已提交
65 66
```

Z
zchen0211 已提交
67
### Class member function: Discriminator Net
Z
gan api  
zchen0211 已提交
68
```python
Z
zchen0211 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81
def discriminator(self, image):
    """Discriminator net: score ``image`` as real vs. generated.

    Returns:
        The output of the final fc layer. No sigmoid is applied here;
        presumably ``pd.cross_entropy`` in build_model() handles that
        (TODO confirm against the final API).
    """
    D_h0 = pd.conv2d(image, self.D_W0, self.D_b0)
    # Same `pd.batch_norm` spelling as used by generator().
    D_h0_bn = pd.batch_norm(D_h0)
    D_h0_relu = pd.lrelu(D_h0_bn)

    D_h1 = pd.conv2d(D_h0_relu, self.D_W1, self.D_b1)
    D_h1_bn = pd.batch_norm(D_h1)
    D_h1_relu = pd.lrelu(D_h1_bn)

    D_h2 = pd.fc(D_h1_relu, self.D_W2, self.D_b2)
    return D_h2
Z
gan api  
zchen0211 已提交
82
```
Z
zchen0211 已提交
83 84

### Class member function: Build the model
Z
gan api  
zchen0211 已提交
85
```python
Z
zchen0211 已提交
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
def build_model(self):
    """Build the input placeholders, the G/D subgraphs and both losses.

    Sets ``self.images``, ``self.faked_images``, ``self.z`` (and ``self.y``
    for the conditional GAN), the generator output ``self.G``, discriminator
    outputs ``self.D_t`` (real) / ``self.D_f`` (generated), and the losses
    ``self.d_loss`` / ``self.g_loss``.
    """
    # input data; the label placeholder only exists for the conditional GAN
    if self.y_dim:
        self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
    self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
    self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
    # Noise batch, same batch size as the other inputs.
    self.z = pd.data(pd.float32, [self.batch_size, self.z_dim])

    if self.y_dim:  # conditional GAN: condition the generator on the labels
        self.G = self.generator(self.z, self.y)
        # NOTE(review): `sampler` is not defined anywhere in this document.
        self.sampled = self.sampler(self.z, self.y)
    else:  # original version of GAN
        self.G = self.generator(self.z)
        self.sampled = self.sampler(self.z)

    # Discriminator on real images ...
    self.D_t = self.discriminator(self.images)
    # ... and on generated images. This must be self.G: scoring self.images
    # again would leave the generator out of both losses.
    self.D_f = self.discriminator(self.G)

    # Discriminator: real images -> 1, generated images -> 0.
    self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)))
    self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)))
    self.d_loss = self.d_loss_real + self.d_loss_fake

    # Generator: make the discriminator label generated images as real.
    self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_size)))
Z
gan api  
zchen0211 已提交
114
```
Z
zchen0211 已提交
115 116

# Main function for the demo:
Z
gan api  
zchen0211 已提交
117
```python
Z
zchen0211 已提交
118 119 120 121 122 123 124 125 126 127
if __name__ == "__main__":

    # dcgan
    dcgan = DCGAN()
    dcgan.build_model()

    # load mnist data
    # NOTE(review): `load_mnist` is not defined in this document; assumed to
    # return (images, labels) arrays that support slicing.
    data_X, data_y = dcgan.load_mnist()

    # Two subgraphs required!!! Each optimizer should update only its own
    # network's parameters (theta_D / theta_G collected in __init__).
    # NOTE(review): the `var_list` keyword name is assumed; confirm against
    # the final optimizer API.
    d_optim = pd.train.Adam(lr=.001, beta=.1).minimize(dcgan.d_loss, var_list=dcgan.theta_D)
    g_optim = pd.train.Adam(lr=.001, beta=.1).minimize(dcgan.g_loss, var_list=dcgan.theta_G)

    # executor
    sess = pd.executor()

    batch_size = dcgan.batch_size
    # Integer division so range() also works under Python 3.
    num_batches = len(data_X) // batch_size

    # training
    for epoch in range(10000):
        for batch_id in range(num_batches):
            # sample a batch (sequential, non-overlapping slices)
            idx = batch_id * batch_size
            batch_im = data_X[idx:idx + batch_size]
            batch_label = data_y[idx:idx + batch_size]
            # sample z
            batch_z = np.random.uniform(-1., 1., [batch_size, dcgan.z_dim])

            # Both subgraphs consume z; labels exist (and are needed by the
            # generator) only for the conditional model.
            feed_dict = {dcgan.z: batch_z}
            if dcgan.y_dim:
                feed_dict[dcgan.y] = batch_label

            if batch_id % 2 == 0:
                # The discriminator step additionally needs the real images.
                feed_dict[dcgan.images] = batch_im
                sess.eval(d_optim, feed_dict=feed_dict)
            else:
                sess.eval(g_optim, feed_dict=feed_dict)
Z
gan api  
zchen0211 已提交
151
```