From 4767fb6719694ee400d3a6c9344aa21edde8bd36 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Mon, 9 Oct 2017 16:05:14 -0700 Subject: [PATCH] gan api modified --- doc/design/gan_api.md | 67 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/doc/design/gan_api.md b/doc/design/gan_api.md index 4fcff8b70a3..77c867bac7e 100644 --- a/doc/design/gan_api.md +++ b/doc/design/gan_api.md @@ -139,10 +139,10 @@ class DCGAN(object): - Define data readers as placeholders to hold the data; - Build generator and discriminators; - Define two training losses for discriminator and generator, respectively. +If we have execution dependency engine to back-trace all tensors, the module building our GAN model will be like this: ```python class DCGAN(object): def build_model(self): - # input data if self.y_dim: self.y = pd.data(pd.float32, [self.batch_size, self.y_dim]) self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) @@ -151,17 +151,17 @@ class DCGAN(object): # step 1: generate images by generator, classify real/fake images with discriminator if self.y_dim: # if conditional GAN, includes label - self.G = self.generator(self.z, self.y) - self.D_t = self.discriminator(self.images) - # generated fake images - self.sampled = self.sampler(self.z, self.y) - self.D_f = self.discriminator(self.images) + self.G = self.generator(self.z, self.y) + self.D_t = self.discriminator(self.images) + # generated fake images + self.sampled = self.sampler(self.z, self.y) + self.D_f = self.discriminator(self.G) else: # original version of GAN - self.G = self.generator(self.z) - self.D_t = self.discriminator(self.images) - # generate fake images - self.sampled = self.sampler(self.z) - self.D_f = self.discriminator(self.images) + self.G = self.generator(self.z) + self.D_t = self.discriminator(self.images) + # generate fake images + self.sampled = self.sampler(self.z) + self.D_f = self.discriminator(self.images) # step 2: define the two losses self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)) @@ -171,6 +171,44 @@ class DCGAN(object): self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_szie)) ``` +If we do not have dependency engine but blocks, the module building our GAN model will be like this: +```python +class DCGAN(object): + def build_model(self, default_block): + # input data in the default block + if self.y_dim: + self.y = pd.data(pd.float32, [self.batch_size, self.y_dim]) + self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) + # self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) + self.z = pd.data(tf.float32, [None, self.z_size]) + + # step 1: generate images by generator, classify real/fake images with discriminator + with pd.default_block().g_block(): + if self.y_dim: # if conditional GAN, includes label + self.G = self.generator(self.z, self.y) + self.D_g = self.discriminator(self.G, self.y) + else: # original version of GAN + self.G = self.generator(self.z) + self.D_g = self.discriminator(self.G, self.y) + self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_g, np.ones(self.batch_szie)) + + with pd.default_block().d_block(): + if self.y_dim: # if conditional GAN, includes label + self.D_t = self.discriminator(self.images, self.y) + self.D_f = self.discriminator(self.G, self.y) + else: # original version of GAN + self.D_t = self.discriminator(self.images) + self.D_f = self.discriminator(self.G) + + # step 2: define the two losses + self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)) + self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)) + self.d_loss = self.d_loss_real + self.d_loss_fake +``` +Some small confusion and problems with this design: +- D\_g and D\_f are actually the same thing, but has to be written twice; +- Requires ability to create a block anytime, rather than in if-else or rnn only; + ## Main function for the demo: Generally, the user of GAN just need to the following things: - Define an object as DCGAN class; @@ -183,9 +221,10 @@ import numpy as np import logging if __name__ == "__main__": - # dcgan - dcgan = DCGAN() - dcgan.build_model() + # dcgan class in the default graph/block + with pd.block() as def_block: + dcgan = DCGAN() + dcgan.build_model(def_block) # load mnist data data_X, data_y = self.load_mnist() -- GitLab