提交 4147c7f2 编写于 作者: Z zchen0211

gan design modified

上级 275d65b5
'''
GAN implementation, just a demo.
'''
```python
# pd for short, should be more concise.
from paddle.v2 as pd
import numpy as np
import logging
```
# Design for GAN
GAN (General Adversarial Net) is an important model for unsupervised learning and widely used in many areas.
It contains several important machine learning concepts, including building and running subgraphs, dependency tracing, different optimizers in one executor and so forth.
In our GAN design, we wrap it as a user-friendly easily customized python API to design different models. We take the conditional DC-GAN as an example due to its good performance on image generation.
## The Conditional-GAN might be a class.
This design we adopt the popular open source design in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains following data structure:
### DCGAN(object):
which contains everything required to build a GAN model. It provides following member functions methods as API:
### __init__(...):
Initialize hyper-parameters (like conv dimension and so forth), and declare model parameters of discriminator and generator as well.
### generator(z, y=None):
Generate a fake image from input noise z. If the label y is provided, the conditional GAN model will be chosen.
Returns a generated image.
### discriminator(image):
Given an image, decide if it is from a real source or a fake one.
Returns a 0/1 binary label.
### build_model(self):
build the whole GAN model, define training loss for both generator and discrimator.
<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
The original GAN paper.
Borrow this photo from the original DC-GAN paper.
</p>
# Conditional-GAN should be a class.
### Class member function: the initializer.
## Discussion on Engine Functions required to build GAN
- Trace the ternsor and variable dependency in the engine executor. (Very critical, otherwise GAN can'be be trained correctly)
- Different optimizers responsible for optimizing different loss.
To be more detailed, we introduce our design of DCGAN as following:
### Class member Function: Initializer
- Set up hyper-parameters, including condtional dimension, noise dimension, batch size and so forth.
- Declare and define all the model variables. All the discriminator parameters are included in the list self.theta_D and all the generator parameters are included in the list self.theta_G.
```python
class DCGAN(object):
def __init__(self, y_dim=None):
......@@ -43,11 +68,16 @@ class DCGAN(object):
self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2]
```
### Class member function: Generator Net
### Class member Function: Generator
- Given a noisy input z, returns a fake image.
- Concatenation, batch-norm, FC operations required;
- Deconv layer required, which is missing now...
```python
def generator(self, z, y = None):
# Generator Net
# input z: the random noise
# input y: input data label (optional)
# output G_im: generated fake images
if not self.y_dim:
z = pd.concat(1, [z, y])
......@@ -64,11 +94,14 @@ def generator(self, z, y = None):
return G_im
```
### Class member function: Discriminator Net
### Class member function: Discriminator
- Given a noisy input z, returns a fake image.
- Concatenation, Convolution, batch-norm, FC, Leaky-ReLU operations required;
```python
def discriminator(self, image):
# input image: either generated images or real ones
# output D_h2: binary logit of the label
# Discriminator Net
D_h0 = pd.conv2d(image, self.D_w0, self.D_b0)
D_h0_bn = pd.batchnorm(h0)
D_h0_relu = pd.lrelu(h0_bn)
......@@ -82,6 +115,9 @@ def discriminator(self, image):
```
### Class member function: Build the model
- Define data readers as placeholders to hold the data;
- Build generator and discriminators;
- Define two training losses for discriminator and generator, respectively.
```python
def build_model(self):
......@@ -92,8 +128,8 @@ def build_model(self):
self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
self.z = pd.data(tf.float32, [None, self.z_size])
# if conditional GAN
if self.y_dim:
# step 1: generate images by generator, classify real/fake images with discriminator
if self.y_dim: # if conditional GAN, includes label
self.G = self.generator(self.z, self.y)
self.D_t = self.discriminator(self.images)
# generated fake images
......@@ -106,6 +142,7 @@ def build_model(self):
self.sampled = self.sampler(self.z)
self.D_f = self.discriminator(self.images)
# step 2: define the two losses
self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size))
self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size))
self.d_loss = self.d_loss_real + self.d_loss_fake
......@@ -113,8 +150,13 @@ def build_model(self):
self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_szie))
```
# Main function for the demo:
## Main function for the demo:
```python
# pd for short, should be more concise.
from paddle.v2 as pd
import numpy as np
import logging
if __name__ == "__main__":
# dcgan
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册