提交 4767fb67 编写于 作者: Z zchen0211

gan api modified

上级 4238b9b9
...@@ -139,10 +139,10 @@ class DCGAN(object): ...@@ -139,10 +139,10 @@ class DCGAN(object):
- Define data readers as placeholders to hold the data; - Define data readers as placeholders to hold the data;
- Build generator and discriminators; - Build generator and discriminators;
- Define two training losses for discriminator and generator, respectively. - Define two training losses for discriminator and generator, respectively.
If we have execution dependency engine to back-trace all tensors, the module building our GAN model will be like this:
```python ```python
class DCGAN(object): class DCGAN(object):
def build_model(self): def build_model(self):
# input data
if self.y_dim: if self.y_dim:
self.y = pd.data(pd.float32, [self.batch_size, self.y_dim]) self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
...@@ -155,7 +155,7 @@ class DCGAN(object): ...@@ -155,7 +155,7 @@ class DCGAN(object):
self.D_t = self.discriminator(self.images) self.D_t = self.discriminator(self.images)
# generated fake images # generated fake images
self.sampled = self.sampler(self.z, self.y) self.sampled = self.sampler(self.z, self.y)
self.D_f = self.discriminator(self.images) self.D_f = self.discriminator(self.G)
else: # original version of GAN else: # original version of GAN
self.G = self.generator(self.z) self.G = self.generator(self.z)
self.D_t = self.discriminator(self.images) self.D_t = self.discriminator(self.images)
...@@ -171,6 +171,44 @@ class DCGAN(object): ...@@ -171,6 +171,44 @@ class DCGAN(object):
self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_szie)) self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_szie))
``` ```
If we do not have dependency engine but blocks, the module building our GAN model will be like this:
```python
class DCGAN(object):
def build_model(self, default_block):
# input data in the default block
if self.y_dim:
self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
# self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
self.z = pd.data(tf.float32, [None, self.z_size])
# step 1: generate images by generator, classify real/fake images with discriminator
with pd.default_block().g_block():
if self.y_dim: # if conditional GAN, includes label
self.G = self.generator(self.z, self.y)
self.D_g = self.discriminator(self.G, self.y)
else: # original version of GAN
self.G = self.generator(self.z)
self.D_g = self.discriminator(self.G, self.y)
self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_g, np.ones(self.batch_szie))
with pd.default_block().d_block():
if self.y_dim: # if conditional GAN, includes label
self.D_t = self.discriminator(self.images, self.y)
self.D_f = self.discriminator(self.G, self.y)
else: # original version of GAN
self.D_t = self.discriminator(self.images)
self.D_f = self.discriminator(self.G)
# step 2: define the two losses
self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size))
self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size))
self.d_loss = self.d_loss_real + self.d_loss_fake
```
Some small confusion and problems with this design:
- D\_g and D\_f are actually the same thing, but has to be written twice;
- Requires ability to create a block anytime, rather than in if-else or rnn only;
## Main function for the demo: ## Main function for the demo:
Generally, the user of GAN just need to the following things: Generally, the user of GAN just need to the following things:
- Define an object as DCGAN class; - Define an object as DCGAN class;
...@@ -183,9 +221,10 @@ import numpy as np ...@@ -183,9 +221,10 @@ import numpy as np
import logging import logging
if __name__ == "__main__": if __name__ == "__main__":
# dcgan # dcgan class in the default graph/block
with pd.block() as def_block:
dcgan = DCGAN() dcgan = DCGAN()
dcgan.build_model() dcgan.build_model(def_block)
# load mnist data # load mnist data
data_X, data_y = self.load_mnist() data_X, data_y = self.load_mnist()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册