From cbc9a59c33b507f26ad4e00e740672ef99bc8fa4 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Fri, 2 Feb 2018 10:03:09 +0800 Subject: [PATCH] Allow uers to specify the name of moving mean and variance in batch_norm interface. --- python/paddle/v2/fluid/layers/nn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index c38e21087d..cb8a4815db 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1478,7 +1478,9 @@ def batch_norm(input, param_attr=None, bias_attr=None, data_layout='NCHW', - name=None): + name=None, + moving_mean_name=None, + moving_variance_name=None): """ This function helps create an operator to implement the BatchNorm layer using the configurations from the input parameters. @@ -1508,6 +1510,7 @@ def batch_norm(input, attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) mean = helper.create_global_variable( + name=moving_mean_name, dtype=input.dtype, shape=param_shape, persistable=True, @@ -1515,6 +1518,7 @@ def batch_norm(input, helper.set_variable_initializer(var=mean, initializer=Constant(0.0)) variance = helper.create_global_variable( + name=moving_variance_name, dtype=input.dtype, shape=param_shape, persistable=True, -- GitLab