From 4147c7f22836fe7ae7b0c6e616adaba0bbfe3b3a Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 4 Oct 2017 15:52:23 -0700 Subject: [PATCH] gan design modified --- doc/design/gan_api.md | 82 ++++++++++++++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 20 deletions(-) diff --git a/doc/design/gan_api.md b/doc/design/gan_api.md index eb0bc1c003a..b107f2fc000 100644 --- a/doc/design/gan_api.md +++ b/doc/design/gan_api.md @@ -1,20 +1,45 @@ -''' -GAN implementation, just a demo. -''' -```python -# pd for short, should be more concise. -from paddle.v2 as pd -import numpy as np -import logging -``` +# Design for GAN + +GAN (Generative Adversarial Net) is an important model for unsupervised learning and is widely used in many areas. + +It contains several important machine learning concepts, including building and running subgraphs, dependency tracing, different optimizers in one executor and so forth. + +In our GAN design, we wrap it as a user-friendly, easily customized Python API for designing different models. We take the conditional DC-GAN as an example due to its good performance on image generation. + +## The Conditional-GAN might be a class. +In this design we adopt the popular open-source designs in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains the following data structure: + +### DCGAN(object): +which contains everything required to build a GAN model. It provides the following member functions as its API: + +### __init__(...): +Initialize hyper-parameters (like conv dimension and so forth), and declare model parameters of discriminator and generator as well. + +### generator(z, y=None): +Generate a fake image from input noise z. If the label y is provided, the conditional GAN model will be chosen. +Returns a generated image. + +### discriminator(image): +Given an image, decide if it is from a real source or a fake one. +Returns a 0/1 binary label. 
+ +### build_model(self): +Build the whole GAN model and define the training losses for both the generator and the discriminator.


-The original GAN paper. +This figure is borrowed from the original DC-GAN paper.

-# Conditional-GAN should be a class. -### Class member function: the initializer. +## Discussion on Engine Functions required to build GAN +- Trace the tensor and variable dependency in the engine executor. (Very critical, otherwise GAN can't be trained correctly) +- Different optimizers responsible for optimizing different losses. + +To be more detailed, we introduce our design of DCGAN as follows: + +### Class member Function: Initializer +- Set up hyper-parameters, including conditional dimension, noise dimension, batch size and so forth. +- Declare and define all the model variables. All the discriminator parameters are included in the list self.theta_D and all the generator parameters are included in the list self.theta_G. ```python class DCGAN(object): def __init__(self, y_dim=None): @@ -43,11 +68,16 @@ class DCGAN(object): self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2] ``` -### Class member function: Generator Net +### Class member Function: Generator +- Given a random noise input z, returns a fake image. +- Concatenation, batch-norm, FC operations required; +- Deconv layer required, which is missing now... ```python def generator(self, z, y = None): - - # Generator Net + # input z: the random noise + # input y: input data label (optional) + # output G_im: generated fake images + if not self.y_dim: z = pd.concat(1, [z, y]) @@ -64,11 +94,14 @@ def generator(self, z, y = None): return G_im ``` -### Class member function: Discriminator Net +### Class member function: Discriminator +- Given an image (either real or generated), returns a binary logit indicating whether it is real or fake. 
+- Concatenation, Convolution, batch-norm, FC, Leaky-ReLU operations required; ```python def discriminator(self, image): + # input image: either generated images or real ones + # output D_h2: binary logit of the label - # Discriminator Net D_h0 = pd.conv2d(image, self.D_w0, self.D_b0) D_h0_bn = pd.batchnorm(h0) D_h0_relu = pd.lrelu(h0_bn) @@ -82,6 +115,9 @@ def discriminator(self, image): ``` ### Class member function: Build the model +- Define data readers as placeholders to hold the data; +- Build generator and discriminators; +- Define two training losses for discriminator and generator, respectively. ```python def build_model(self): @@ -92,8 +128,8 @@ def build_model(self): self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) self.z = pd.data(tf.float32, [None, self.z_size]) - # if conditional GAN - if self.y_dim: + # step 1: generate images by generator, classify real/fake images with discriminator + if self.y_dim: # if conditional GAN, includes label self.G = self.generator(self.z, self.y) self.D_t = self.discriminator(self.images) # generated fake images @@ -106,6 +142,7 @@ def build_model(self): self.sampled = self.sampler(self.z) self.D_f = self.discriminator(self.images) + # step 2: define the two losses self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)) self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)) self.d_loss = self.d_loss_real + self.d_loss_fake @@ -113,8 +150,13 @@ def build_model(self): self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_szie)) ``` -# Main function for the demo: +## Main function for the demo: ```python +# pd for short, should be more concise. +from paddle.v2 as pd +import numpy as np +import logging + if __name__ == "__main__": # dcgan -- GitLab