From 1dedaada9b84e0e0529599a28f67c1f05c7ce59e Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 7 Feb 2023 20:39:41 +0800 Subject: [PATCH] Remove axis in some elementwise api (#50190) * remove axis in some elementwise api * fix inplace bug eager-gen * fix bug * revert change for CheckInplace * polish code --- .../generator/eager_gen.py | 58 +++++----- paddle/phi/api/yaml/legacy_ops.yaml | 2 +- python/paddle/fluid/layers/nn.py | 22 ---- .../ir/inference/test_trt_elementwise_op.py | 8 +- .../unittests/test_elementwise_nn_grad.py | 2 +- python/paddle/tensor/math.py | 103 +++--------------- 6 files changed, 51 insertions(+), 144 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 3497a1217c..2e720f8800 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -261,6 +261,9 @@ FORWARD_ONLY_FUNCTION_TEMPLATE = """ // Get Outputs {} VLOG(4) << \"Finish AD API: {}"; + + // Check Inplace if needed +{}{} // LOG IF DEBUG {} // Returns @@ -1462,8 +1465,31 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): returns_str = ", ".join(returns_list) returns_str = f"{returns_type_str}{{{returns_str}}}" + # Check Inplace + check_inplace_str = "" + bump_inplace_version_str = "" + # Note: When the name of original api in yaml is end of '_', that means this api is a + # special inplace api and it doesn't require checking and bumping version (except assign_out_). + # This rule is obscure, so we maybe replace it by adding new design in the future. + if is_inplaced and ( + self.forward_api_name[-1] != '_' + or self.forward_api_name == 'assign_out_' + ): + for inplace_name in forward_inplace_map.keys(): + if ( + not self.is_forward_only + and forward_api_name not in inplace_check_blacklist + ): + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, GetAutoGradMetaName(inplace_name) + ) + bump_inplace_version_str += ( + BUMP_INPLACE_VERSION_TEMPLATE.format( + inplace_name, inplace_name + ) + ) + # Node Creation Pre-Processing - inputs_names = [] if not self.is_forward_only: # 1. Get Input AutoGradMeta inputs_autograd_meta_list = [] @@ -1478,13 +1504,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ) in backward_grad_outputs_map.items(): if pos == corresponding_pos: has_corresponding_grad_output = True - if ( - has_corresponding_grad_output - or ( - name in forward_inplace_map - and forward_api_name not in inplace_check_blacklist - ) - or self.is_forward_only + if has_corresponding_grad_output or ( + name in forward_inplace_map + and forward_api_name not in inplace_check_blacklist ): input_autograd_meta_name = GetAutoGradMetaName(name) if IsPlainTensorType(ttype): @@ -1532,24 +1554,6 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): outputs_autograd_meta_list.append(output_autograd_meta) outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) - # 3. Check Inplace - check_inplace_str = "" - bump_inplace_version_str = "" - if is_inplaced: - for inplace_name in forward_inplace_map.keys(): - if forward_api_name not in inplace_check_blacklist: - inplace_autograd_meta_name = GetAutoGradMetaName( - inplace_name - ) - check_inplace_str += CHECK_INPLACE_TEMPLATE.format( - inplace_name, inplace_autograd_meta_name - ) - bump_inplace_version_str += ( - BUMP_INPLACE_VERSION_TEMPLATE.format( - inplace_name, inplace_name - ) - ) - # Node Creation self.GenerateNodeCreationCodes() node_creation_str = self.node_creation_str @@ -1643,6 +1647,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): forward_call_str, get_outputs_str, forward_api_name, + check_inplace_str, + bump_inplace_version_str, log_str, returns_str, ) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 286e9841ef..94bcff5b89 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1429,7 +1429,7 @@ - op : remainder args : (Tensor x, Tensor y) - output : Tensor + output : Tensor (out) infer_meta : func : ElementwiseInferMeta kernel : diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1dd819df41..4cb8b8f88f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -101,28 +101,6 @@ def _get_reduce_dim(dim, input): return reduce_all, dim -@dygraph_only -def _elementwise_op_in_dygraph( - x, y, axis=-1, act=None, use_mkldnn=False, op_name=None -): - def is_inplace(op_name): - return op_name[-1] == "_" - - if op_name not in OP_NAMEMAPPING.keys() or axis != -1: - op = getattr(_legacy_C_ops, op_name) - out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) - else: - if in_dygraph_mode(): - op = getattr( - _C_ops, - OP_NAMEMAPPING[op_name] if not is_inplace(op_name) else op_name, - ) - out = op(x, y) - return dygraph_utils._append_activation_in_dygraph( - out, act, use_mkldnn=use_mkldnn - ) - - @deprecated(since="2.0.0", update_to="paddle.nn.functional.embedding") def embedding( input, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py index 7674a22658..9505504060 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py @@ -50,7 +50,7 @@ class TensorRTSubgraphPassElementwiseBroadcastTest(InferencePassTest): self.fetch_list = [out] def append_eltwise(self, data1, data2): - return paddle.tensor.math._add_with_axis(x=data1, y=data2, axis=0) + return paddle.tensor.math.add(x=data1, y=data2) def test_check_output(self): if os.path.exists(self.path + "_opt_cache"): @@ -67,21 +67,21 @@ class TensorRTSubgraphPassElementwiseBroadcastTest1( TensorRTSubgraphPassElementwiseBroadcastTest ): def append_eltwise(self, data1, data2): - return paddle.tensor.math._subtract_with_axis(x=data1, y=data2, axis=0) + return paddle.tensor.math.subtract(x=data1, y=data2) class TensorRTSubgraphPassElementwiseBroadcastTest2( TensorRTSubgraphPassElementwiseBroadcastTest ): def append_eltwise(self, data1, data2): - return paddle.tensor.math._multiply_with_axis(x=data1, y=data2, axis=0) + return paddle.tensor.math.multiply(x=data1, y=data2) class TensorRTSubgraphPassElementwiseBroadcastTest3( TensorRTSubgraphPassElementwiseBroadcastTest ): def append_eltwise(self, data1, data2): - return paddle.tensor.math._divide_with_axis(x=data1, y=data2, axis=0) + return paddle.tensor.math.divide(x=data1, y=data2) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py b/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py index 2c5da64817..0b9c8c9a54 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py @@ -222,7 +222,7 @@ class TestElementwiseDivDoubleGradCheck(unittest.TestCase): y = paddle.static.data('y', shape, dtype) x.persistable = True y.persistable = True - out = paddle.tensor.math._divide_with_axis(x, y, axis=0) + out = paddle.tensor.math.divide(x, y) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) y_arr = np.random.uniform(-1, 1, shape).astype(dtype) y_arr[np.abs(y_arr) < 0.005] = 0.02 diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6f797b82e1..edd44e4e83 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -480,31 +480,6 @@ OP_NAMEMAPPING = { } -@dygraph_only -def _elementwise_op_in_dygraph( - x, y, axis=-1, act=None, use_mkldnn=False, op_name=None -): - def is_inplace(op_name): - return op_name[-1] == "_" - - if op_name not in OP_NAMEMAPPING.keys() or axis != -1: - op = getattr(_legacy_C_ops, op_name) - out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) - else: - if in_dygraph_mode(): - op = getattr( - _C_ops, - OP_NAMEMAPPING[op_name] if not is_inplace(op_name) else op_name, - ) - out = op(x, y) - if act is None: - return out - else: - return dygraph_utils._append_activation_in_dygraph( - out, act, use_mkldnn=use_mkldnn - ) - - def _elementwise_op(helper): op_type = helper.layer_type original_op_type = helper.kwargs.get('original_op_type', op_type) @@ -616,8 +591,6 @@ def add_(x, y, name=None): Inplace version of ``add`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_tensor_add`. """ - op_type = 'elementwise_add_' - axis = -1 out_shape = broadcast_shape(x.shape, y.shape) if out_shape != x.shape: @@ -627,11 +600,7 @@ def add_(x, y, name=None): ) ) - if in_dygraph_mode(): - return _C_ops.add_(x, y) - else: - out = _elementwise_op_in_dygraph(x, y, axis=axis, op_name=op_type) - return out + return _C_ops.add_(x, y) def subtract(x, y, name=None): @@ -690,13 +659,10 @@ def subtract(x, y, name=None): # Tensor(shape=[3], dtype=float64, place=Place(cpu), stop_gradient=True, # [ 4. , inf., -inf.]) """ - op_type = 'elementwise_sub' - axis = -1 - act = None if in_dygraph_mode(): return _C_ops.subtract(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_sub', **locals())) @inplace_apis_in_dygraph_only @@ -705,8 +671,6 @@ def subtract_(x, y, name=None): Inplace version of ``subtract`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_tensor_subtract`. """ - axis = -1 - act = None out_shape = broadcast_shape(x.shape, y.shape) if out_shape != x.shape: @@ -716,13 +680,7 @@ def subtract_(x, y, name=None): ) ) - if in_dygraph_mode(): - return _C_ops.subtract_(x, y) - else: - out = _elementwise_op_in_dygraph( - x, y, axis=axis, act=act, op_name='elementwise_sub_' - ) - return out + return _C_ops.subtract_(x, y) def divide(x, y, name=None): @@ -757,13 +715,10 @@ def divide(x, y, name=None): print(z) # [2., 0.6, 2.] """ - op_type = 'elementwise_div' - axis = -1 - act = None if in_dygraph_mode(): return _C_ops.divide(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_div', **locals())) def floor_divide(x, y, name=None): @@ -800,12 +755,10 @@ def floor_divide(x, y, name=None): print(z) # [2, 0, 2, 2] """ - op_type = 'elementwise_floordiv' - axis = -1 if in_dygraph_mode(): return _C_ops.floor_divide(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_floordiv', **locals())) def remainder(x, y, name=None): @@ -841,13 +794,10 @@ def remainder(x, y, name=None): print(z) # [0, 3, 2, 1] """ - op_type = 'elementwise_mod' - axis = -1 - if in_dygraph_mode(): return _C_ops.remainder(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_mod', **locals())) @inplace_apis_in_dygraph_only @@ -856,9 +806,6 @@ def remainder_(x, y, name=None): Inplace version of ``remainder`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_tensor_remainder`. """ - op_type = 'elementwise_mod_' - axis = -1 - out_shape = broadcast_shape(x.shape, y.shape) if out_shape != x.shape: raise ValueError( @@ -866,8 +813,7 @@ def remainder_(x, y, name=None): out_shape, x.shape ) ) - - return _elementwise_op_in_dygraph(x, y, axis=axis, op_name=op_type) + return _C_ops.remainder_(x, y) mod = remainder # noqa: F841 @@ -911,10 +857,6 @@ def multiply(x, y, name=None): print(res) # [[[2, 4, 6], [2, 4, 6]]] """ - op_type = 'elementwise_mul' - act = None - axis = -1 - if in_dygraph_mode(): return _C_ops.multiply(x, y) else: @@ -924,7 +866,7 @@ def multiply(x, y, name=None): % (x.dtype, y.dtype) ) - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_mul', **locals())) @dygraph_only @@ -958,7 +900,6 @@ def _add_with_axis(x, y, axis=-1, name=None): return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "add") else: op_type = 'elementwise_add' - act = None return _elementwise_op(LayerHelper(op_type, **locals())) @@ -970,7 +911,6 @@ def _subtract_with_axis(x, y, axis=-1, name=None): ) else: op_type = 'elementwise_sub' - act = None return _elementwise_op(LayerHelper(op_type, **locals())) @@ -982,7 +922,6 @@ def _multiply_with_axis(x, y, axis=-1, name=None): ) else: op_type = 'elementwise_mul' - act = None return _elementwise_op(LayerHelper(op_type, **locals())) @@ -992,7 +931,6 @@ def _divide_with_axis(x, y, axis=-1, name=None): return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "divide") else: op_type = 'elementwise_div' - act = None return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1052,13 +990,10 @@ def maximum(x, y, name=None): # Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, # [5. , 3. , inf.]) """ - op_type = 'elementwise_max' - axis = -1 - act = None if in_dygraph_mode(): return _C_ops.maximum(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_max', **locals())) def minimum(x, y, name=None): @@ -1117,13 +1052,10 @@ def minimum(x, y, name=None): # Tensor(shape=[3], dtype=float64, place=Place(cpu), stop_gradient=True, # [ 1. , -inf., 5. ]) """ - op_type = 'elementwise_min' - axis = -1 - act = None if in_dygraph_mode(): return _C_ops.minimum(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_min', **locals())) def fmax(x, y, name=None): @@ -1184,13 +1116,10 @@ def fmax(x, y, name=None): # Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, # [5. , 3. , inf.]) """ - op_type = 'elementwise_fmax' - axis = -1 - act = None if in_dygraph_mode(): return _C_ops.fmax(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_fmax', **locals())) def fmin(x, y, name=None): @@ -1251,13 +1180,10 @@ def fmin(x, y, name=None): # Tensor(shape=[3], dtype=float64, place=Place(cpu), stop_gradient=True, # [ 1. , -inf., 5. ]) """ - op_type = 'elementwise_fmin' - axis = -1 - act = None if in_dygraph_mode(): return _C_ops.fmin(x, y) else: - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_fmin', **locals())) def sum(x, axis=None, dtype=None, keepdim=False, name=None): @@ -4888,9 +4814,6 @@ def frac(x, name=None): # [[ 0.22000003, -0.02999997], # [-0.54999995, 0.66000003]]) """ - op_type = 'elementwise_sub' - axis = -1 - act = None if x.dtype not in [ paddle.int32, paddle.int64, @@ -4917,7 +4840,7 @@ def frac(x, name=None): helper.append_op( type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y} ) - return _elementwise_op(LayerHelper(op_type, **locals())) + return _elementwise_op(LayerHelper('elementwise_sub', **locals())) def sgn(x, name=None): -- GitLab