未验证 提交 c318befa 编写于 作者: L lvmengsi 提交者: GitHub

fix_spade_init (#3722)

上级 b1673312
...@@ -40,6 +40,7 @@ class SPADE_model(object): ...@@ -40,6 +40,7 @@ class SPADE_model(object):
padding=1, padding=1,
name=name + "_fc", name=name + "_fc",
use_bias=True, use_bias=True,
initial="kaiming",
is_test=is_test) is_test=is_test)
x = self.SPADEResnetBlock( x = self.SPADEResnetBlock(
x, x,
...@@ -88,6 +89,7 @@ class SPADE_model(object): ...@@ -88,6 +89,7 @@ class SPADE_model(object):
padding=1, padding=1,
name=name + "_conv_img", name=name + "_conv_img",
use_bias=True, use_bias=True,
initial="kaiming",
is_test=is_test) is_test=is_test)
x = fluid.layers.tanh(x) x = fluid.layers.tanh(x)
...@@ -148,6 +150,7 @@ class SPADE_model(object): ...@@ -148,6 +150,7 @@ class SPADE_model(object):
padding=pw, padding=pw,
activation_fn='relu', activation_fn='relu',
name=name + ".mlp_shared.0", name=name + ".mlp_shared.0",
initial="kaiming",
use_bias=True) use_bias=True)
gamma = conv2d( gamma = conv2d(
actv, actv,
...@@ -155,6 +158,7 @@ class SPADE_model(object): ...@@ -155,6 +158,7 @@ class SPADE_model(object):
ks, ks,
padding=pw, padding=pw,
name=name + ".mlp_gamma", name=name + ".mlp_gamma",
initial="kaiming",
use_bias=True) use_bias=True)
beta = conv2d( beta = conv2d(
actv, actv,
...@@ -162,6 +166,7 @@ class SPADE_model(object): ...@@ -162,6 +166,7 @@ class SPADE_model(object):
ks, ks,
padding=pw, padding=pw,
name=name + ".mlp_beta", name=name + ".mlp_beta",
initial="kaiming",
use_bias=True) use_bias=True)
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
name=name + ".param_free_norm.weight", name=name + ".param_free_norm.weight",
...@@ -219,6 +224,7 @@ def build_discriminator_Nlayers(input, ...@@ -219,6 +224,7 @@ def build_discriminator_Nlayers(input,
name=name + ".model0.0", name=name + ".model0.0",
activation_fn='leaky_relu', activation_fn='leaky_relu',
relufactor=0.2, relufactor=0.2,
initial="kaiming",
use_bias=True) use_bias=True)
d_dims = d_base_dims d_dims = d_base_dims
res_list.append(res1) res_list.append(res1)
...@@ -248,6 +254,7 @@ def build_discriminator_Nlayers(input, ...@@ -248,6 +254,7 @@ def build_discriminator_Nlayers(input,
0.02, 0.02,
1, 1,
name + ".model{}.0".format(d_nlayers), name + ".model{}.0".format(d_nlayers),
initial="kaiming",
use_bias=True) use_bias=True)
res_list.append(o_c4) res_list.append(o_c4)
return res_list return res_list
...@@ -427,7 +427,8 @@ def conv2d_spectral_norm(input, ...@@ -427,7 +427,8 @@ def conv2d_spectral_norm(input,
dtype = helper.input_dtype() dtype = helper.input_dtype()
weight_param = fluid.ParamAttr( weight_param = fluid.ParamAttr(
name=name + ".weight_orig", name=name + ".weight_orig",
initializer=fluid.initializer.Constant(1.0), initializer=fluid.initializer.Normal(
loc=0.0, scale=1.0),
trainable=True) trainable=True)
weight = helper.create_parameter( weight = helper.create_parameter(
attr=weight_param, attr=weight_param,
...@@ -438,7 +439,9 @@ def conv2d_spectral_norm(input, ...@@ -438,7 +439,9 @@ def conv2d_spectral_norm(input,
weight = weight_spectral_norm weight = weight_spectral_norm
if use_bias: if use_bias:
bias_attr = fluid.ParamAttr( bias_attr = fluid.ParamAttr(
name=name + "_b", initializer=fluid.initializer.Constant(0.0)) name=name + "_b",
initializer=fluid.initializer.Normal(
loc=0.0, scale=1.0))
else: else:
bias_attr = False bias_attr = False
conv = conv2d_with_filter( conv = conv2d_with_filter(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册