提交 cbc9a59c 编写于 作者: D dangqingqing

Allow uers to specify the name of moving mean and variance in batch_norm interface.

上级 292c1951
...@@ -1478,7 +1478,9 @@ def batch_norm(input, ...@@ -1478,7 +1478,9 @@ def batch_norm(input,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
data_layout='NCHW', data_layout='NCHW',
name=None): name=None,
moving_mean_name=None,
moving_variance_name=None):
""" """
This function helps create an operator to implement This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters. the BatchNorm layer using the configurations from the input parameters.
...@@ -1508,6 +1510,7 @@ def batch_norm(input, ...@@ -1508,6 +1510,7 @@ def batch_norm(input,
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable( mean = helper.create_global_variable(
name=moving_mean_name,
dtype=input.dtype, dtype=input.dtype,
shape=param_shape, shape=param_shape,
persistable=True, persistable=True,
...@@ -1515,6 +1518,7 @@ def batch_norm(input, ...@@ -1515,6 +1518,7 @@ def batch_norm(input,
helper.set_variable_initializer(var=mean, initializer=Constant(0.0)) helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
variance = helper.create_global_variable( variance = helper.create_global_variable(
name=moving_variance_name,
dtype=input.dtype, dtype=input.dtype,
shape=param_shape, shape=param_shape,
persistable=True, persistable=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册