Unverified commit fe341bac — authored by: Q QI JUN, committed by: GitHub

refine batch norm python layer (#7348)

Parent 1566af8a
...@@ -120,11 +120,12 @@ class LayerHelper(object): ...@@ -120,11 +120,12 @@ class LayerHelper(object):
raise ValueError("no Parameter name %s found" % name) raise ValueError("no Parameter name %s found" % name)
return param return param
def create_tmp_variable(self, dtype, stop_gradient=False):
    """Create an anonymous temporary variable in the current block.

    Args:
        dtype: data type of the new variable.
        stop_gradient (bool): if True, mark the variable so it is
            excluded from gradient computation (useful for outputs such
            as batch-norm saved mean/variance that are statistics, not
            differentiable results). Defaults to False, which preserves
            the original behavior for existing callers.

    Returns:
        The newly created, non-persistable Variable.
    """
    # A unique name is generated under this helper's namespace so that
    # repeated calls never collide with each other.
    return self.main_program.current_block().create_var(
        name=unique_name(".".join([self.name, 'tmp'])),
        dtype=dtype,
        persistable=False,
        stop_gradient=stop_gradient)
def create_variable(self, *args, **kwargs):
    """Create a variable in the current block of the main program.

    Every positional and keyword argument is forwarded unchanged to
    ``Block.create_var``; the return value is whatever that call yields.
    """
    current_block = self.main_program.current_block()
    return current_block.create_var(*args, **kwargs)
......
...@@ -971,11 +971,17 @@ def batch_norm(input, ...@@ -971,11 +971,17 @@ def batch_norm(input,
attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=True) attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable( mean = helper.create_global_variable(
dtype=input.dtype, shape=param_shape, persistable=True) dtype=input.dtype,
shape=param_shape,
persistable=True,
stop_gradient=True)
helper.set_variable_initializer(var=mean, initializer=Constant(0.0)) helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
variance = helper.create_global_variable( variance = helper.create_global_variable(
dtype=input.dtype, shape=param_shape, persistable=True) dtype=input.dtype,
shape=param_shape,
persistable=True,
stop_gradient=True)
helper.set_variable_initializer(var=variance, initializer=Constant(1.0)) helper.set_variable_initializer(var=variance, initializer=Constant(1.0))
# create output # create output
...@@ -983,8 +989,8 @@ def batch_norm(input, ...@@ -983,8 +989,8 @@ def batch_norm(input,
mean_out = mean mean_out = mean
# variance and variance out share the same memory # variance and variance out share the same memory
variance_out = variance variance_out = variance
saved_mean = helper.create_tmp_variable(dtype) saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
saved_variance = helper.create_tmp_variable(dtype) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
batch_norm_out = helper.create_tmp_variable(dtype) batch_norm_out = helper.create_tmp_variable(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register