未验证 提交 aa475557 编写于 作者: L lvmengsi 提交者: GitHub

Fix cgan and dcgan (#2828)

* fix bn warning
上级 03dbae4c
......@@ -35,6 +35,10 @@ class CGAN_model(object):
self.gf_dim = 128
self.df_dim = 64
self.leaky_relu_factor = 0.2
if self.batch_size == 1:
self.norm = None
else:
self.norm = "batch_norm"
def network_G(self, input, label, name="generator"):
# concat noise and label
......@@ -43,14 +47,14 @@ class CGAN_model(object):
o_l1 = linear(
xy,
self.gf_dim * 8,
norm='batch_norm',
norm=self.norm,
activation_fn='relu',
name=name + '_l1')
o_c1 = fluid.layers.concat([o_l1, y], 1)
o_l2 = linear(
o_c1,
self.gf_dim * (self.img_w // 4) * (self.img_h // 4),
norm='batch_norm',
norm=self.norm,
activation_fn='relu',
name=name + '_l2')
o_r1 = fluid.layers.reshape(
......@@ -107,7 +111,7 @@ class CGAN_model(object):
o_l3 = linear(
o_c2,
self.df_dim * 16,
norm='batch_norm',
norm=self.norm,
activation_fn='leaky_relu',
name=name + '_l3')
o_c3 = fluid.layers.concat([o_l3, y], 1)
......
......@@ -31,13 +31,17 @@ class DCGAN_model(object):
self.dfc_dim = 1024
self.gf_dim = 64
self.df_dim = 64
if self.batch_size == 1:
self.norm = None
else:
self.norm = "batch_norm"
def network_G(self, input, name="generator"):
o_l1 = linear(input, self.gfc_dim, norm='batch_norm', name=name + '_l1')
o_l1 = linear(input, self.gfc_dim, norm=self.norm, name=name + '_l1')
o_l2 = linear(
o_l1,
self.gf_dim * 2 * self.img_dim // 4 * self.img_dim // 4,
norm='batch_norm',
norm=self.norm,
name=name + '_l2')
o_r1 = fluid.layers.reshape(
o_l2, [-1, self.df_dim * 2, self.img_dim // 4, self.img_dim // 4])
......@@ -85,7 +89,7 @@ class DCGAN_model(object):
o_l1 = linear(
o_c2,
self.dfc_dim,
norm='batch_norm',
norm=self.norm,
activation_fn='leaky_relu',
name=name + '_l1')
out = linear(o_l1, 1, activation_fn='sigmoid', name=name + '_l2')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册