未验证 提交 aa475557 编写于 作者: L lvmengsi 提交者: GitHub

Fix cgan and dcgan (#2828)

* fix bn warning
上级 03dbae4c
...@@ -35,6 +35,10 @@ class CGAN_model(object): ...@@ -35,6 +35,10 @@ class CGAN_model(object):
self.gf_dim = 128 self.gf_dim = 128
self.df_dim = 64 self.df_dim = 64
self.leaky_relu_factor = 0.2 self.leaky_relu_factor = 0.2
if self.batch_size == 1:
self.norm = None
else:
self.norm = "batch_norm"
def network_G(self, input, label, name="generator"): def network_G(self, input, label, name="generator"):
# concat noise and label # concat noise and label
...@@ -43,14 +47,14 @@ class CGAN_model(object): ...@@ -43,14 +47,14 @@ class CGAN_model(object):
o_l1 = linear( o_l1 = linear(
xy, xy,
self.gf_dim * 8, self.gf_dim * 8,
norm='batch_norm', norm=self.norm,
activation_fn='relu', activation_fn='relu',
name=name + '_l1') name=name + '_l1')
o_c1 = fluid.layers.concat([o_l1, y], 1) o_c1 = fluid.layers.concat([o_l1, y], 1)
o_l2 = linear( o_l2 = linear(
o_c1, o_c1,
self.gf_dim * (self.img_w // 4) * (self.img_h // 4), self.gf_dim * (self.img_w // 4) * (self.img_h // 4),
norm='batch_norm', norm=self.norm,
activation_fn='relu', activation_fn='relu',
name=name + '_l2') name=name + '_l2')
o_r1 = fluid.layers.reshape( o_r1 = fluid.layers.reshape(
...@@ -107,7 +111,7 @@ class CGAN_model(object): ...@@ -107,7 +111,7 @@ class CGAN_model(object):
o_l3 = linear( o_l3 = linear(
o_c2, o_c2,
self.df_dim * 16, self.df_dim * 16,
norm='batch_norm', norm=self.norm,
activation_fn='leaky_relu', activation_fn='leaky_relu',
name=name + '_l3') name=name + '_l3')
o_c3 = fluid.layers.concat([o_l3, y], 1) o_c3 = fluid.layers.concat([o_l3, y], 1)
......
...@@ -31,13 +31,17 @@ class DCGAN_model(object): ...@@ -31,13 +31,17 @@ class DCGAN_model(object):
self.dfc_dim = 1024 self.dfc_dim = 1024
self.gf_dim = 64 self.gf_dim = 64
self.df_dim = 64 self.df_dim = 64
if self.batch_size == 1:
self.norm = None
else:
self.norm = "batch_norm"
def network_G(self, input, name="generator"): def network_G(self, input, name="generator"):
o_l1 = linear(input, self.gfc_dim, norm='batch_norm', name=name + '_l1') o_l1 = linear(input, self.gfc_dim, norm=self.norm, name=name + '_l1')
o_l2 = linear( o_l2 = linear(
o_l1, o_l1,
self.gf_dim * 2 * self.img_dim // 4 * self.img_dim // 4, self.gf_dim * 2 * self.img_dim // 4 * self.img_dim // 4,
norm='batch_norm', norm=self.norm,
name=name + '_l2') name=name + '_l2')
o_r1 = fluid.layers.reshape( o_r1 = fluid.layers.reshape(
o_l2, [-1, self.df_dim * 2, self.img_dim // 4, self.img_dim // 4]) o_l2, [-1, self.df_dim * 2, self.img_dim // 4, self.img_dim // 4])
...@@ -85,7 +89,7 @@ class DCGAN_model(object): ...@@ -85,7 +89,7 @@ class DCGAN_model(object):
o_l1 = linear( o_l1 = linear(
o_c2, o_c2,
self.dfc_dim, self.dfc_dim,
norm='batch_norm', norm=self.norm,
activation_fn='leaky_relu', activation_fn='leaky_relu',
name=name + '_l1') name=name + '_l1')
out = linear(o_l1, 1, activation_fn='sigmoid', name=name + '_l2') out = linear(o_l1, 1, activation_fn='sigmoid', name=name + '_l2')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册