未验证 提交 4e7e39b4 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #8069 from qingqing01/bn_name

Allow uers to specify the name of moving mean and variance in batch_norm interface.
......@@ -1478,7 +1478,9 @@ def batch_norm(input,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
name=None):
name=None,
moving_mean_name=None,
moving_variance_name=None):
"""
This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters.
......@@ -1508,6 +1510,7 @@ def batch_norm(input,
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable(
name=moving_mean_name,
dtype=input.dtype,
shape=param_shape,
persistable=True,
......@@ -1515,6 +1518,7 @@ def batch_norm(input,
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
variance = helper.create_global_variable(
name=moving_variance_name,
dtype=input.dtype,
shape=param_shape,
persistable=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册