From b0ec8efbc66c023d4be5e9f06c45b0627a4835b7 Mon Sep 17 00:00:00 2001
From: Xiaoxu Chen
Date: Mon, 26 Sep 2022 15:32:33 +0800
Subject: [PATCH] add batch_norm orig2prim transform rule (#46446)

* Support both use_calc_stream and sync_op in send recv APIs (#46023)

* add batch_norm orig2prim rule

Co-authored-by: Wen Sun <35923278+HermitSun@users.noreply.github.com>
---
 .../unittests/autograd/test_orig2prim.py     | 76 +++++++++++++++
 .../tests/unittests/autograd/test_primapi.py | 34 ++++++-
 python/paddle/incubate/autograd/primops.py   | 94 +++++++++++++++++++
 python/paddle/incubate/autograd/primrules.py | 51 ++++++++--
 4 files changed, 247 insertions(+), 8 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
index d289f3dad9c..8b3c0a2be90 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -900,5 +900,81 @@ class TestRSqrtOrig2Prim(TestElementWiseAddOrig2Prim):
         self.out_map = {0: self.output['Out']}
 
 
+class TestBatchnormOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'batch_norm'
+        x = paddle.static.data(name='X', shape=[5, 8], dtype='float')
+        m = paddle.static.data(name='Mean', shape=[8], dtype='float')
+        v = paddle.static.data(name='Variance', shape=[8], dtype='float')
+        w = paddle.static.data(name='Scale', shape=[8], dtype='float')
+        b = paddle.static.data(name='Bias', shape=[8], dtype='float')
+
+        self.input = {
+            "X": [x],
+            "Scale": [w],
+            "Bias": [b],
+            "Mean": [m],
+            "Variance": [v]
+        }
+        saved_variance = self.layer_help.create_variable_for_type_inference(
+            dtype=x.dtype, stop_gradient=True)
+        batch_norm_out = self.layer_help.create_variable_for_type_inference(
+            x.dtype)
+        saved_mean = self.layer_help.create_variable_for_type_inference(
+            dtype=x.dtype, stop_gradient=True)
+        self.output = {
+            "Y": [batch_norm_out],
+            "MeanOut": [m],
+            "VarianceOut": [v],
+            "SavedMean": [saved_mean],
+            "SavedVariance": [saved_variance]
+        }
+
+        self.attrs = {
+            "momentum": 0.9,
+            "epsilon": 1e-5,
+            "is_test": False,
+            "data_layout": 'NCHW',
+            "use_mkldnn": False,
+            "fuse_with_relu": False,
+            "use_global_stats": False,
+            "trainable_statistics": False,
+        }
+        self.orig2prim_args = (b, m, None, w, v, x)
+        self.all_ops = [
+            'add_p', 'add_p', 'add_p', 'add_p', 'batch_norm', 'broadcast_p',
+            'broadcast_p', 'broadcast_p', 'broadcast_p', 'broadcast_p', 'div_p',
+            'div_p', 'div_p', 'fill_constant_p', 'fill_constant_p',
+            'fill_constant_p', 'fill_constant_p', 'fill_constant_p',
+            'fill_constant_p', 'fill_constant_p', 'fill_constant_p',
+            'fill_constant_p', 'mul_p', 'mul_p', 'mul_p', 'mul_p', 'mul_p',
+            'pow_p', 'reduce_sum_p', 'reduce_sum_p', 'reshape_p', 'reshape_p',
+            'reshape_p', 'reshape_p', 'sqrt_p', 'sub_p', 'sub_p', 'sub_p',
+            'sub_p'
+        ]
+        # { prim_op_output_index: orig_op_output_var }
+        self.out_map = {}
+
+
+class TestFillConstantOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'fill_constant'
+
+        self.attrs = {'value': 1., 'shape': (2, 3), 'dtype': paddle.float32}
+        self.input = {}
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(
+                dtype=paddle.float32)
+        }
+
+        self.orig2prim_args = (None, None, None)
+        self.all_ops = ['fill_constant', 'fill_constant_p']
+        # { prim_op_output_index: orig_op_output_var }
+        self.out_map = {0: self.output['Out']}
+
+
 if __name__ == '__main__':
     unittest.main()
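Note on the test above: `all_ops` is the sorted list of operator types expected
in the program block after the orig2prim rule has run; as the fill_constant
case shows, the original op is still listed alongside the primitives it lowers
to, so a single `batch_norm` expands into roughly forty `*_p` primitives. A
minimal sketch of how such a list can be collected from a static program (the
program and the op built here are illustrative, not part of the test harness):

    import paddle

    paddle.enable_static()
    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data(name='X', shape=[5, 8], dtype='float32')
        y = paddle.tanh(x)  # stand-in for the op under test
    # Sorted op types of the global block, comparable to `all_ops`.
    print(sorted(op.type for op in main.block(0).ops))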
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
index 1a086e12f20..cd5004a815a 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -318,7 +318,39 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
      lambda x: paddle.var(x, axis=1, unbiased=False),
      (np.random.rand(10, 20, 30), ), None, 'float32'),
     ('var_with_keepdim', lambda x: paddle.var(x, axis=1, keepdim=True),
-     (np.random.rand(10, 20, 30), ), None, 'float32')))
+     (np.random.rand(10, 20, 30), ), None, 'float32'),
+    ('bn', lambda x, w, b: paddle.nn.functional.batch_norm(
+        x, paddle.ones((10, )), paddle.ones(
+            (10, )), w, b), (np.random.rand(10, 10), np.random.rand(10),
+                             np.random.rand(10)), None, 'float32'),
+    ('bn_train', lambda x, w, b: paddle.nn.functional.batch_norm(
+        x, paddle.ones((10, )), paddle.ones((10, )), w, b, training=True),
+     (np.random.rand(
+         10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
+    ('bn_nhwc', lambda x, w, b: paddle.nn.functional.batch_norm(
+        x,
+        paddle.ones((10, )) + 1,
+        paddle.ones((10, )),
+        w,
+        b,
+        training=True,
+        data_format='NHWC',
+    ), (np.random.rand(
+        10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
+    ('bn_global_stat',
+     lambda x, w, b: paddle.nn.functional.batch_norm(x,
+                                                     paddle.ones(
+                                                         (10, )) + 3.2,
+                                                     paddle.ones(
+                                                         (10, )) + 6.7,
+                                                     w,
+                                                     b,
+                                                     training=True,
+                                                     data_format='NHWC',
+                                                     use_global_stats=True),
+     (np.random.rand(
+         10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
+))
 class TestGrad(unittest.TestCase):
 
     def setUp(self):
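The new `bn*` cases drive `paddle.nn.functional.batch_norm` through the
primitive-based autograd. A rough standalone sketch of the same pattern
(static graph with primitives enabled; the return/fetch plumbing of
`paddle.incubate.autograd.grad` follows its documented usage, so treat the
details as an assumption rather than the test's own harness):

    import numpy as np
    import paddle
    from paddle.incubate.autograd import disable_prim, enable_prim, grad

    paddle.enable_static()
    enable_prim()  # lower supported ops to primitives before differentiation

    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data('x', shape=[10, 10], dtype='float32')
        x.stop_gradient = False
        w = paddle.static.data('w', shape=[10], dtype='float32')
        b = paddle.static.data('b', shape=[10], dtype='float32')
        y = paddle.nn.functional.batch_norm(x, paddle.ones((10, )),
                                            paddle.ones((10, )), w, b,
                                            training=True)
        x_grad = grad(y, x)

    exe = paddle.static.Executor()
    exe.run(startup)
    gx, = exe.run(main,
                  feed={
                      'x': np.random.rand(10, 10).astype('float32'),
                      'w': np.random.rand(10).astype('float32'),
                      'b': np.random.rand(10).astype('float32'),
                  },
                  fetch_list=[x_grad])
    disable_prim()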
+ """ + reduce_axes = tuple(i for i in range(len(x.shape)) if i != axis) + stats_shape = tuple(1 if i in reduce_axes else s + for i, s in enumerate(x.shape)) + + batch_mean = zeros(run_mean.shape, run_mean.dtype) + batch_var = zeros(run_var.shape, run_var.dtype) + + if not use_run_stat: + batch_mean = mean(x, reduce_axes, keepdim=True) + batch_var = mean(square(sub(x, broadcast(batch_mean, x.shape))), + reduce_axes, + keepdim=True) + x_hat = div( + sub(x, broadcast(batch_mean, x.shape)), + sqrt( + add(broadcast(batch_var, x.shape), + fill_const(eps, x.shape, batch_var.dtype)))) + + momentum = fill_const(momentum, run_mean.shape, run_mean.dtype) + run_mean = add( + mul(momentum, run_mean), + mul(sub(ones(run_mean.shape, run_mean.dtype), momentum), + reshape(batch_mean, run_mean.shape))) + run_var = add( + mul(momentum, run_var), + mul(sub(ones(run_var.shape, run_var.dtype), momentum), + reshape(batch_var, run_var.shape))) + else: + x_hat = div( + sub(x, broadcast(reshape(run_mean, stats_shape), x.shape)), + sqrt( + add(broadcast(reshape(run_var, stats_shape), x.shape), + fill_const(eps, x.shape, x.dtype)))) + y = add(mul(broadcast(reshape(gamma, stats_shape), x_hat.shape), x_hat), + broadcast(reshape(beta, stats_shape), x_hat.shape)) + + if reserve_space: + return run_mean, reserve_space, batch_mean, batch_var, run_var, y + else: + return run_mean, batch_mean, batch_var, run_var, y + + +def square(x): + return pow(x, fill_const(2., x.shape, x.dtype)) + + @REGISTER_FN('add_p', 'X', 'Y', 'Z') def add(x, y, out=None): return _simple_binop(LayerHelper('add_p', **locals())) diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 4dbcc421498..a6a29e04184 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -215,6 +215,20 @@ def fill_any_like_orig2prim(op, x): convert_dtype(INT_DTYPE_2_STRING[op.attr('dtype')]))) +@REGISTER_ORIG2PRIM('fill_constant') +def fill_const_orig2prim(op, + shape_tensor=None, + shape_tensor_list=None, + value_tensor=None): + if shape_tensor or shape_tensor_list or value_tensor: + raise TypeError( + 'fill_const_orig2prim currently not support Tensor input of shape and value.' 
diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py
index 4dbcc421498..a6a29e04184 100644
--- a/python/paddle/incubate/autograd/primrules.py
+++ b/python/paddle/incubate/autograd/primrules.py
@@ -215,6 +215,20 @@ def fill_any_like_orig2prim(op, x):
             convert_dtype(INT_DTYPE_2_STRING[op.attr('dtype')])))
 
 
+@REGISTER_ORIG2PRIM('fill_constant')
+def fill_const_orig2prim(op,
+                         shape_tensor=None,
+                         shape_tensor_list=None,
+                         value_tensor=None):
+    if shape_tensor or shape_tensor_list or value_tensor:
+        raise TypeError(
+            'fill_const_orig2prim currently does not support Tensor inputs for shape or value.'
+        )
+    return fill_const(value=op.attr('value'),
+                      shape=op.attr('shape'),
+                      dtype=paddle.dtype(op.attr('dtype')))
+
+
 @REGISTER_ORIG2PRIM('sum')
 def sum_orig2prim(op, xs):
     x0 = xs[0]
@@ -391,7 +405,7 @@ def pow_orig2prim(op, x, y):
 
 @REGISTER_ORIG2PRIM('square')
 def square_orig2prim(op, x):
-    return primops.pow(x, fill_const(2., x.shape, x.dtype))
+    return primops.square(x)
 
 
 @REGISTER_ORIG2PRIM('elementwise_max')
@@ -436,12 +450,35 @@ def reduce_sum_orig2prim(op, x):
 def reduce_mean_orig2prim(op, x):
     axes = tuple(range(0, len(
         x.shape))) if op.attr('reduce_all') else op.attr('dim')
-    sum = reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim'))
-    norm = fill_const(shape=sum.shape,
-                      value=functools.reduce(operator.mul,
-                                             [x.shape[axis] for axis in axes]),
-                      dtype=sum.dtype)
-    return div(sum, norm)
+    return primops.mean(x, axes, op.attr('keep_dim'))
+
+
+@REGISTER_ORIG2PRIM('batch_norm')
+def batch_norm_orig2prim(op, bias, run_mean, momentum_tensor, scale, run_var,
+                         x):
+    momentum = op.attr('momentum')
+    eps = op.attr('epsilon')
+    is_test = op.attr('is_test')
+    data_layout = op.attr('data_layout')
+    use_global_stats = op.attr('use_global_stats')
+    trainable_statistics = op.attr('trainable_statistics')
+    reserve_space = None if len(
+        op.output_names) == 5 else get_output_var_list(op)[1]
+
+    feature_axis = 1 if data_layout in ('NC', 'NCL', 'NCHW',
+                                        'NCDHW') else len(x.shape) - 1
+    use_run_stat = (is_test and (not trainable_statistics)) or use_global_stats
+
+    return primops.batch_norm(x,
+                              feature_axis,
+                              scale,
+                              bias,
+                              run_mean,
+                              run_var,
+                              eps=eps,
+                              momentum=momentum,
+                              use_run_stat=use_run_stat,
+                              reserve_space=reserve_space)
 
 
 @REGISTER_ORIG2PRIM('size')
-- 
GitLab
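Two details of `batch_norm_orig2prim` deserve a note: running statistics are
selected either for pure inference (`is_test` and not `trainable_statistics`)
or whenever `use_global_stats` forces them, and the feature axis follows the
layout string. A small sketch making that branching explicit (the helper names
are hypothetical, written only for illustration):

    def _use_run_stat(is_test, trainable_statistics, use_global_stats):
        # Mirrors: (is_test and (not trainable_statistics)) or use_global_stats
        return (is_test and not trainable_statistics) or use_global_stats

    def _feature_axis(data_layout, ndim):
        # Channels-first layouts normalize over axis 1, otherwise the last axis.
        return 1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCDHW') else ndim - 1

    assert _use_run_stat(True, False, False)       # plain inference
    assert _use_run_stat(False, False, True)       # forced global statistics
    assert not _use_run_stat(False, False, False)  # training
    assert _feature_axis('NHWC', 4) == 3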