提交 79c8bb9e 编写于 作者: Z zchen0211

gan design new version

上级 4147c7f2
...@@ -6,6 +6,11 @@ It contains several important machine learning concepts, including building and ...@@ -6,6 +6,11 @@ It contains several important machine learning concepts, including building and
In our GAN design, we wrap it as a user-friendly easily customized python API to design different models. We take the conditional DC-GAN as an example due to its good performance on image generation. In our GAN design, we wrap it as a user-friendly easily customized python API to design different models. We take the conditional DC-GAN as an example due to its good performance on image generation.
<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
Borrow this photo from the original DC-GAN paper.
</p>
## The Conditional-GAN might be a class. ## The Conditional-GAN might be a class.
This design we adopt the popular open source design in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains following data structure: This design we adopt the popular open source design in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains following data structure:
...@@ -26,11 +31,6 @@ Returns a 0/1 binary label. ...@@ -26,11 +31,6 @@ Returns a 0/1 binary label.
### build_model(self): ### build_model(self):
build the whole GAN model, define training loss for both generator and discrimator. build the whole GAN model, define training loss for both generator and discrimator.
<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
Borrow this photo from the original DC-GAN paper.
</p>
## Discussion on Engine Functions required to build GAN ## Discussion on Engine Functions required to build GAN
- Trace the ternsor and variable dependency in the engine executor. (Very critical, otherwise GAN can'be be trained correctly) - Trace the ternsor and variable dependency in the engine executor. (Very critical, otherwise GAN can'be be trained correctly)
- Different optimizers responsible for optimizing different loss. - Different optimizers responsible for optimizing different loss.
...@@ -151,6 +151,10 @@ def build_model(self): ...@@ -151,6 +151,10 @@ def build_model(self):
``` ```
## Main function for the demo: ## Main function for the demo:
Generally, the user of GAN just need to the following things:
- Define an object as DCGAN class;
- Build the DCGAN model;
- Specify two optimizers for two different losses with respect to different parameters.
```python ```python
# pd for short, should be more concise. # pd for short, should be more concise.
from paddle.v2 as pd from paddle.v2 as pd
...@@ -158,7 +162,6 @@ import numpy as np ...@@ -158,7 +162,6 @@ import numpy as np
import logging import logging
if __name__ == "__main__": if __name__ == "__main__":
# dcgan # dcgan
dcgan = DCGAN() dcgan = DCGAN()
dcgan.build_model() dcgan.build_model()
...@@ -167,8 +170,8 @@ if __name__ == "__main__": ...@@ -167,8 +170,8 @@ if __name__ == "__main__":
data_X, data_y = self.load_mnist() data_X, data_y = self.load_mnist()
# Two subgraphs required!!! # Two subgraphs required!!!
d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(self.d_loss, ) d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.d_loss, dcgan.theta_D)
g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(self.g_loss) g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.g_loss, dcgan.theta_G)
# executor # executor
sess = pd.executor() sess = pd.executor()
...@@ -183,11 +186,11 @@ if __name__ == "__main__": ...@@ -183,11 +186,11 @@ if __name__ == "__main__":
batch_z = np.random.uniform(-1., 1., [batch_size, z_dim]) batch_z = np.random.uniform(-1., 1., [batch_size, z_dim])
if batch_id % 2 == 0: if batch_id % 2 == 0:
sess.eval(d_optim, sess.run(d_optim,
feed_dict = {dcgan.images: batch_im, feed_dict = {dcgan.images: batch_im,
dcgan.y: batch_label, dcgan.y: batch_label,
dcgan.z: batch_z}) dcgan.z: batch_z})
else: else:
sess.eval(g_optim, sess.run(g_optim,
feed_dict = {dcgan.z: batch_z}) feed_dict = {dcgan.z: batch_z})
``` ```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册