提交 79c8bb9e 编写于 作者: Z zchen0211

gan design new version

上级 4147c7f2
......@@ -6,6 +6,11 @@ It contains several important machine learning concepts, including building and
In our GAN design, we wrap it as a user-friendly easily customized python API to design different models. We take the conditional DC-GAN as an example due to its good performance on image generation.
<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
Borrow this photo from the original DC-GAN paper.
</p>
## The Conditional-GAN might be a class.
This design we adopt the popular open source design in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains following data structure:
......@@ -26,11 +31,6 @@ Returns a 0/1 binary label.
### build_model(self):
build the whole GAN model, define training loss for both generator and discrimator.
<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
Borrow this photo from the original DC-GAN paper.
</p>
## Discussion on Engine Functions required to build GAN
- Trace the ternsor and variable dependency in the engine executor. (Very critical, otherwise GAN can'be be trained correctly)
- Different optimizers responsible for optimizing different loss.
......@@ -151,6 +151,10 @@ def build_model(self):
```
## Main function for the demo:
Generally, the user of GAN just need to the following things:
- Define an object as DCGAN class;
- Build the DCGAN model;
- Specify two optimizers for two different losses with respect to different parameters.
```python
# pd for short, should be more concise.
from paddle.v2 as pd
......@@ -158,7 +162,6 @@ import numpy as np
import logging
if __name__ == "__main__":
# dcgan
dcgan = DCGAN()
dcgan.build_model()
......@@ -167,8 +170,8 @@ if __name__ == "__main__":
data_X, data_y = self.load_mnist()
# Two subgraphs required!!!
d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(self.d_loss, )
g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(self.g_loss)
d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.d_loss, dcgan.theta_D)
g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.g_loss, dcgan.theta_G)
# executor
sess = pd.executor()
......@@ -183,11 +186,11 @@ if __name__ == "__main__":
batch_z = np.random.uniform(-1., 1., [batch_size, z_dim])
if batch_id % 2 == 0:
sess.eval(d_optim,
sess.run(d_optim,
feed_dict = {dcgan.images: batch_im,
dcgan.y: batch_label,
dcgan.z: batch_z})
else:
sess.eval(g_optim,
sess.run(g_optim,
feed_dict = {dcgan.z: batch_z})
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册