From 79c8bb9e7acbe2bc91625e4a2e396994c4fef168 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 4 Oct 2017 16:02:07 -0700 Subject: [PATCH] gan design new version --- doc/design/gan_api.md | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/doc/design/gan_api.md b/doc/design/gan_api.md index b107f2fc000..8521bc8bf26 100644 --- a/doc/design/gan_api.md +++ b/doc/design/gan_api.md @@ -6,6 +6,11 @@ It contains several important machine learning concepts, including building and In our GAN design, we wrap it as a user-friendly easily customized python API to design different models. We take the conditional DC-GAN as an example due to its good performance on image generation. +

+
+Borrow this photo from the original DC-GAN paper. +

+ ## The Conditional-GAN might be a class. This design we adopt the popular open source design in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains following data structure: @@ -26,11 +31,6 @@ Returns a 0/1 binary label. ### build_model(self): build the whole GAN model, define training loss for both generator and discrimator. -

-
-Borrow this photo from the original DC-GAN paper. -

- ## Discussion on Engine Functions required to build GAN - Trace the ternsor and variable dependency in the engine executor. (Very critical, otherwise GAN can'be be trained correctly) - Different optimizers responsible for optimizing different loss. @@ -151,6 +151,10 @@ def build_model(self): ``` ## Main function for the demo: +Generally, the user of GAN just need to the following things: +- Define an object as DCGAN class; +- Build the DCGAN model; +- Specify two optimizers for two different losses with respect to different parameters. ```python # pd for short, should be more concise. from paddle.v2 as pd @@ -158,7 +162,6 @@ import numpy as np import logging if __name__ == "__main__": - # dcgan dcgan = DCGAN() dcgan.build_model() @@ -167,8 +170,8 @@ if __name__ == "__main__": data_X, data_y = self.load_mnist() # Two subgraphs required!!! - d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(self.d_loss, ) - g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(self.g_loss) + d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.d_loss, dcgan.theta_D) + g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.g_loss, dcgan.theta_G) # executor sess = pd.executor() @@ -183,11 +186,11 @@ if __name__ == "__main__": batch_z = np.random.uniform(-1., 1., [batch_size, z_dim]) if batch_id % 2 == 0: - sess.eval(d_optim, + sess.run(d_optim, feed_dict = {dcgan.images: batch_im, dcgan.y: batch_label, dcgan.z: batch_z}) else: - sess.eval(g_optim, + sess.run(g_optim, feed_dict = {dcgan.z: batch_z}) ``` -- GitLab