diff --git a/doc/design/gan_api.md b/doc/design/gan_api.md
index eb0bc1c003a13d7d3dccd1d3fda043d1681cce3a..b107f2fc0000e6c9e2c5f78ab00b0cf906e6f251 100644
--- a/doc/design/gan_api.md
+++ b/doc/design/gan_api.md
@@ -1,20 +1,45 @@
-'''
-GAN implementation, just a demo.
-'''
-```python
-# pd for short, should be more concise.
-from paddle.v2 as pd
-import numpy as np
-import logging
-```
+# Design for GAN
+
+GAN (Generative Adversarial Net) is an important model for unsupervised learning and is widely used in many areas.
+
+It contains several important machine learning concepts, including building and running subgraphs, dependency tracing, different optimizers in one executor and so forth.
+
+In our GAN design, we wrap it as a user-friendly, easily customized Python API for designing different models. We take the conditional DC-GAN as an example due to its good performance on image generation.
+
+## The Conditional-GAN might be a class.
+In this design, we adopt the popular open-source designs in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains the following data structure:
+
+### DCGAN(object):
+which contains everything required to build a GAN model. It provides the following member functions as its API:
+
+### __init__(...):
+Initialize hyper-parameters (like conv dimension and so forth), and declare model parameters of discriminator and generator as well.
+
+### generator(z, y=None):
+Generate a fake image from input noise z. If the label y is provided, the conditional GAN model will be chosen.
+Returns a generated image.
+
+### discriminator(image):
+Given an image, decide if it is from a real source or a fake one.
+Returns a 0/1 binary label.
+
+### build_model(self):
+Build the whole GAN model, and define the training losses for both the generator and the discriminator.
-The original GAN paper.
+This figure is borrowed from the original DC-GAN paper.
-# Conditional-GAN should be a class.
-### Class member function: the initializer.
+## Discussion on Engine Functions required to build GAN
+- Trace the tensor and variable dependencies in the engine executor. (Very critical; otherwise the GAN cannot be trained correctly.)
+- Different optimizers are responsible for optimizing different losses.
+
+In more detail, our design of DCGAN is as follows:
+
+### Class member function: Initializer
+- Set up hyper-parameters, including conditional dimension, noise dimension, batch size and so forth.
+- Declare and define all the model variables. All the discriminator parameters are included in the list self.theta_D and all the generator parameters are included in the list self.theta_G.
```python
class DCGAN(object):
def __init__(self, y_dim=None):
@@ -43,11 +68,16 @@ class DCGAN(object):
self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2]
```
-### Class member function: Generator Net
+### Class member function: Generator
+- Given a noise input z, returns a fake image.
+- Concatenation, batch-norm, FC operations required;
+- Deconv layer required, which is missing now...
```python
def generator(self, z, y = None):
-
- # Generator Net
+ # input z: the random noise
+ # input y: input data label (optional)
+ # output G_im: generated fake images
+
if not self.y_dim:
z = pd.concat(1, [z, y])
@@ -64,11 +94,14 @@ def generator(self, z, y = None):
return G_im
```
-### Class member function: Discriminator Net
+### Class member function: Discriminator
+- Given an image (either real or generated), decides whether it comes from a real source or is a fake one.
+- Concatenation, Convolution, batch-norm, FC, Leaky-ReLU operations required;
```python
def discriminator(self, image):
+ # input image: either generated images or real ones
+ # output D_h2: binary logit of the label
- # Discriminator Net
D_h0 = pd.conv2d(image, self.D_w0, self.D_b0)
D_h0_bn = pd.batchnorm(h0)
D_h0_relu = pd.lrelu(h0_bn)
@@ -82,6 +115,9 @@ def discriminator(self, image):
```
### Class member function: Build the model
+- Define data readers as placeholders to hold the data;
+- Build generator and discriminators;
+- Define two training losses for discriminator and generator, respectively.
```python
def build_model(self):
@@ -92,8 +128,8 @@ def build_model(self):
self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
self.z = pd.data(tf.float32, [None, self.z_size])
- # if conditional GAN
- if self.y_dim:
+ # step 1: generate images by generator, classify real/fake images with discriminator
+ if self.y_dim: # if conditional GAN, includes label
self.G = self.generator(self.z, self.y)
self.D_t = self.discriminator(self.images)
# generated fake images
@@ -106,6 +142,7 @@ def build_model(self):
self.sampled = self.sampler(self.z)
self.D_f = self.discriminator(self.images)
+ # step 2: define the two losses
self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size))
self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size))
self.d_loss = self.d_loss_real + self.d_loss_fake
@@ -113,8 +150,13 @@ def build_model(self):
self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_szie))
```
-# Main function for the demo:
+## Main function for the demo:
```python
+# pd for short, should be more concise.
+import paddle.v2 as pd
+import numpy as np
+import logging
+
if __name__ == "__main__":
# dcgan