From 35a5b9b99756188f2782ed19b4eaca57cb44ceea Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Mon, 9 Oct 2017 16:22:49 -0700 Subject: [PATCH] gan api --- doc/design/gan_api.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/doc/design/gan_api.md b/doc/design/gan_api.md index 77c867bac7..ed7622920b 100644 --- a/doc/design/gan_api.md +++ b/doc/design/gan_api.md @@ -222,6 +222,10 @@ import logging if __name__ == "__main__": # dcgan class in the default graph/block + # if we use dependency engine as tensorflow + # the codes, will be slightly different like: + # dcgan = DCGAN() + # dcgan.build_model() with pd.block() as def_block: dcgan = DCGAN() dcgan.build_model(def_block) @@ -230,8 +234,12 @@ if __name__ == "__main__": data_X, data_y = self.load_mnist() # Two subgraphs required!!! - d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.d_loss, dcgan.theta_D) - g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.g_loss, dcgan.theta_G) + with pd.block().d_block(): + d_optim = pd.train.Adam(lr = .001, beta= .1) + d_step = d_optim.minimize(dcgan.d_loss, dcgan.theta_D) + with pd.block.g_block(): + g_optim = pd.train.Adam(lr = .001, beta= .1) + g_step = pd.minimize(dcgan.g_loss, dcgan.theta_G) # executor sess = pd.executor() @@ -246,11 +254,11 @@ if __name__ == "__main__": batch_z = np.random.uniform(-1., 1., [batch_size, z_dim]) if batch_id % 2 == 0: - sess.run(d_optim, + sess.run(d_step, feed_dict = {dcgan.images: batch_im, dcgan.y: batch_label, dcgan.z: batch_z}) else: - sess.run(g_optim, + sess.run(g_step, feed_dict = {dcgan.z: batch_z}) ``` -- GitLab