提交 35a5b9b9 编写于 作者: Z zchen0211

gan api

上级 4767fb67
...@@ -222,6 +222,10 @@ import logging ...@@ -222,6 +222,10 @@ import logging
if __name__ == "__main__": if __name__ == "__main__":
# dcgan class in the default graph/block # dcgan class in the default graph/block
# if we use dependency engine as tensorflow
# the codes, will be slightly different like:
# dcgan = DCGAN()
# dcgan.build_model()
with pd.block() as def_block: with pd.block() as def_block:
dcgan = DCGAN() dcgan = DCGAN()
dcgan.build_model(def_block) dcgan.build_model(def_block)
...@@ -230,8 +234,12 @@ if __name__ == "__main__": ...@@ -230,8 +234,12 @@ if __name__ == "__main__":
data_X, data_y = self.load_mnist() data_X, data_y = self.load_mnist()
# Two subgraphs required!!! # Two subgraphs required!!!
d_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.d_loss, dcgan.theta_D) with pd.block().d_block():
g_optim = pd.train.Adam(lr = .001, beta= .1).minimize(dcgan.g_loss, dcgan.theta_G) d_optim = pd.train.Adam(lr = .001, beta= .1)
d_step = d_optim.minimize(dcgan.d_loss, dcgan.theta_D)
with pd.block.g_block():
g_optim = pd.train.Adam(lr = .001, beta= .1)
g_step = pd.minimize(dcgan.g_loss, dcgan.theta_G)
# executor # executor
sess = pd.executor() sess = pd.executor()
...@@ -246,11 +254,11 @@ if __name__ == "__main__": ...@@ -246,11 +254,11 @@ if __name__ == "__main__":
batch_z = np.random.uniform(-1., 1., [batch_size, z_dim]) batch_z = np.random.uniform(-1., 1., [batch_size, z_dim])
if batch_id % 2 == 0: if batch_id % 2 == 0:
sess.run(d_optim, sess.run(d_step,
feed_dict = {dcgan.images: batch_im, feed_dict = {dcgan.images: batch_im,
dcgan.y: batch_label, dcgan.y: batch_label,
dcgan.z: batch_z}) dcgan.z: batch_z})
else: else:
sess.run(g_optim, sess.run(g_step,
feed_dict = {dcgan.z: batch_z}) feed_dict = {dcgan.z: batch_z})
``` ```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册