From 7c1e7e46aa09b1b4f2f0288d9c19ae4168ce62e2 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Fri, 19 Aug 2022 19:42:53 +0800
Subject: [PATCH] call final_state method in inplace APIs (#42968)

* add forward inplace final state api
* fix bug
* fix reshape
* fix coverage
* add inplace info for erfinv, lerp, put_along_axis
* fix put_along_axis infer_meta
* fix format
* update yaml
* fix
---
 paddle/phi/api/yaml/legacy_api.yaml           | 40 +++++---
 .../fluid/tests/unittests/test_inplace.py     | 11 +++
 python/paddle/nn/functional/activation.py     | 13 ++-
 .../paddle/tensor/layer_function_generator.py |  4 +
 python/paddle/tensor/manipulation.py          | 95 ++++++++++++++-----
 python/paddle/tensor/math.py                  | 24 +++--
 6 files changed, 142 insertions(+), 45 deletions(-)

diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index e37706a126b..26ada67d957 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -92,11 +92,12 @@
 
 - api : add
   args : (Tensor x, Tensor y)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : ElementwiseInferMeta
   kernel :
     func : add
+  inplace : (x -> out)
   backward : add_grad
 
 - api : add_n
@@ -470,6 +471,7 @@
     func : UnchangedInferMeta
   kernel :
     func : ceil
+  inplace : (x -> out)
   backward : ceil_grad
 
 - api : celu
@@ -779,12 +781,13 @@
 # elu
 - api : elu
   args : (Tensor x, float alpha)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
   kernel :
     func : elu
+  inplace : (x -> out)
   backward : elu_grad
 
 - api : embedding
@@ -836,11 +839,12 @@
 # exp
 - api : exp
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : exp
+  inplace : (x -> out)
   backward : exp_grad
 
 # expand
@@ -958,6 +962,7 @@
     func : UnchangedInferMeta
   kernel :
     func : floor
+  inplace : (x -> out)
   backward : floor_grad
 
 - api : floor_divide
@@ -1432,11 +1437,12 @@
 
 - api : lerp
   args : (Tensor x, Tensor y, Tensor weight)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : LerpInferMeta
   kernel :
     func : lerp
+  inplace : (x -> out)
   backward : lerp_grad
 
 - api : less_equal
@@ -2052,13 +2058,14 @@
 # put_along_axis
 - api : put_along_axis
   args : (Tensor x, Tensor index, Tensor value, int axis, str reduce)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
-    param : [index]
+    param : [x]
   kernel :
     func : put_along_axis
     data_type : x
+  inplace : (x -> out)
   backward : put_along_axis_grad
 
 - api : qr
@@ -2105,11 +2112,12 @@
 
 - api : reciprocal
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : reciprocal
+  inplace : (x -> out)
   backward : reciprocal_grad
 
 # reduce_prod
@@ -2253,6 +2261,7 @@
     func : UnchangedInferMeta
   kernel :
     func : round
+  inplace : (x -> out)
   backward : round_grad
 
 - api : rsqrt
@@ -2279,12 +2288,13 @@
 
 - api : scatter
   args : (Tensor x, Tensor index, Tensor updates, bool overwrite)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : ScatterInferMeta
     dtype : x
   kernel :
     func : scatter
+  inplace : (x -> out)
   backward : scatter_grad
 
 - api : scatter_nd_add
@@ -2463,12 +2473,13 @@
 
 - api : softmax
   args : (Tensor x, int axis)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : SoftmaxInferMeta
   kernel :
     func : softmax
     use_gpudnn : true
+  inplace : (x -> out)
   backward : softmax_grad
 
 - api : softplus
@@ -2510,11 +2521,12 @@
 
 - api : sqrt
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : sqrt
+  inplace : (x -> out)
   backward : sqrt_grad
 
 - api : square
@@ -2542,6 +2554,7 @@
     func : SqueezeWithXShapeInferMeta
   kernel :
     func : squeeze_with_xshape
+  inplace : (x -> out)
   view: (x -> out)
   intermediate : xshape
   backward : squeeze_grad
@@ -2566,11 +2579,12 @@
 
 - api : subtract
   args : (Tensor x, Tensor y)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : ElementwiseInferMeta
   kernel :
     func : subtract
+  inplace : (x -> out)
   backward : subtract_grad
 
 - api : sum
@@ -2640,11 +2654,12 @@
 # tanh
 - api : tanh
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : tanh
+  inplace : (x -> out)
   backward : tanh_grad
 
 # tanh_shrink
@@ -2817,6 +2832,7 @@
     func : UnsqueezeWithXShapeInferMeta
   kernel :
     func : unsqueeze_with_xshape
+  inplace : (x -> out)
   view: (x -> out)
   intermediate : xshape
   backward : unsqueeze_grad
diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py
index 94e30a5e8a1..7588339e95e 100644
--- a/python/paddle/fluid/tests/unittests/test_inplace.py
+++ b/python/paddle/fluid/tests/unittests/test_inplace.py
@@ -302,6 +302,17 @@ class TestDygraphInplaceReshape(TestDygraphInplace):
         return paddle.reshape_(var, [-1])
 
 
+class TestDygraphInplaceReshapeTensor(TestDygraphInplace):
+
+    def non_inplace_api_processing(self, var):
+        shape = paddle.to_tensor(-1)
+        return paddle.reshape(var, shape)
+
+    def inplace_api_processing(self, var):
+        shape = paddle.to_tensor(-1)
+        return paddle.reshape_(var, shape)
+
+
 class TestDygraphInplaceFlatten(TestDygraphInplace):
 
     def non_inplace_api_processing(self, var):
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 74fa0e70c72..a210f806fc4 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -136,6 +136,8 @@ def elu_(x, alpha=1.0, name=None):
     Please refer to :ref:`api_nn_cn_elu`.
     """
     assert alpha >= 0., "elu_ only support alpha >= 0, please use elu instead."
+    if in_dygraph_mode():
+        return _C_ops.final_state_elu_(x, alpha)
     return _C_ops.elu_(x, 'alpha', alpha)
 
 
@@ -1146,7 +1148,16 @@ def softmax_(x, axis=-1, dtype=None, name=None):
     if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
         dtype = convert_np_dtype_to_dtype_(dtype)
     use_cudnn = True
-    return _C_ops.softmax_(x, 'axis', axis, 'use_cudnn', use_cudnn)
+
+    if in_dygraph_mode():
+        outs_cast = x if dtype is None \
+            else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
+        return _C_ops.final_state_softmax_(outs_cast, axis)
+
+    if _in_legacy_dygraph():
+        outs_cast = x if dtype is None \
+            else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
+        return _C_ops.softmax_(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn)
 
 
 def softplus(x, beta=1, threshold=20, name=None):
diff --git a/python/paddle/tensor/layer_function_generator.py b/python/paddle/tensor/layer_function_generator.py
index c6e8df67dec..15d81624591 100644
--- a/python/paddle/tensor/layer_function_generator.py
+++ b/python/paddle/tensor/layer_function_generator.py
@@ -305,6 +305,10 @@ def generate_inplace_fn(inplace_op_type):
     origin_op_type = inplace_op_type[:-1]
 
     def func(x, name=None):
+        final_state_inplace_op_type = "final_state_%s" % inplace_op_type
+        if in_dygraph_mode() and hasattr(_C_ops, final_state_inplace_op_type):
+            op = getattr(_C_ops, final_state_inplace_op_type)
+            return op(x)
         if _non_static_mode():
             op = getattr(_C_ops, inplace_op_type)
             return op(x)
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 43837d03d3a..bf660d20141 100755
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1524,9 +1524,14 @@ def flatten_(x, start_axis=0, stop_axis=-1, name=None):
     if start_axis > stop_axis:
         raise ValueError("The stop_axis should be larger than stat_axis")
 
-    dy_out, _ = _C_ops.flatten_contiguous_range_(x, 'start_axis', start_axis,
-                                                 'stop_axis', stop_axis)
-    return dy_out
+    if in_dygraph_mode():
+        return _C_ops.final_state_flatten_(x, start_axis, stop_axis)
+
+    if _in_legacy_dygraph():
+        dy_out, _ = _C_ops.flatten_contiguous_range_(x, 'start_axis',
+                                                     start_axis, 'stop_axis',
+                                                     stop_axis)
+        return dy_out
 
 
 def roll(x, shifts, axis=None, name=None):
@@ -2052,8 +2057,13 @@ def squeeze_(x, axis=None, name=None):
     elif isinstance(axis, tuple):
         axis = list(axis)
 
-    out, _ = _C_ops.squeeze2_(x, 'axes', axis)
-    return out
+    input = x
+    axes = axis
+    if in_dygraph_mode():
+        return _C_ops.final_state_squeeze_(input, axes)
+    if _in_legacy_dygraph():
+        out, _ = _C_ops.squeeze2_(input, 'axes', axes)
+        return out
 
 
 def unique_consecutive(x,
@@ -2412,16 +2422,20 @@ def unsqueeze_(x, axis, name=None):
     Inplace version of ``unsqueeze`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_paddle_tensor_unsqueeze`.
     """
-    if isinstance(axis, int):
-        axis = [axis]
-    elif isinstance(axis, Variable):
-        axis = axis.numpy().tolist()
-    elif isinstance(axis, (list, tuple)):
-        axis = [
+    input = x
+    axes = axis
+    if isinstance(axes, int):
+        axes = [axes]
+    elif isinstance(axes, Variable):
+        axes = axes.numpy().tolist()
+    elif isinstance(axes, (list, tuple)):
+        axes = [
             item.numpy().item(0) if isinstance(item, Variable) else item
-            for item in axis
+            for item in axes
         ]
-    out, _ = _C_ops.unsqueeze2_(x, 'axes', axis)
+    if in_dygraph_mode():
+        return _C_ops.final_state_unsqueeze_(input, axes)
+    out, _ = _C_ops.unsqueeze2_(input, 'axes', axes)
     return out
 
 
@@ -2679,6 +2693,8 @@ def scatter_(x, index, updates, overwrite=True, name=None):
     Inplace version of ``scatter`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_paddle_tensor_scatter`.
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_scatter_(x, index, updates, overwrite)
     return _C_ops.scatter_(x, index, updates, 'overwrite', overwrite)
 
 
@@ -3272,13 +3288,13 @@ def reshape(x, shape, name=None):
         )
         if isinstance(shape, (list, tuple)):
             shape = [
-                item.numpy().item(0) if isinstance(item, Variable) else item
-                for item in shape
+                item.numpy().item(0)
+                if isinstance(item, tmp_tensor_type) else item for item in shape
             ]
             out = _C_ops.final_state_reshape(x, shape)
         elif isinstance(shape, tmp_tensor_type):
             shape.stop_gradient = True
-            out, _ = _C_ops.reshape2(x, shape)
+            out = _C_ops.final_state_reshape(x, shape)
         else:
             raise ValueError(
                 "shape must be an instance of `list`, `tuple` or `Variable`,"
@@ -3386,17 +3402,41 @@ def reshape_(x, shape, name=None):
     Inplace version of ``reshape`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_paddle_tensor_reshape`.
    """
-    if isinstance(shape, (list, tuple)):
-        shape = [
-            item.numpy().item(0) if isinstance(item, Variable) else item
-            for item in shape
-        ]
-        out, _ = _C_ops.reshape2_(x, None, 'shape', shape)
-        return out
-    elif isinstance(shape, Variable):
-        shape.stop_gradient = True
-        out, _ = _C_ops.reshape2_(x, shape)
+    if in_dygraph_mode():
+        tmp_tensor_type = core.eager.Tensor
+        if isinstance(shape, (list, tuple)):
+            shape = [
+                item.numpy().item(0)
+                if isinstance(item, tmp_tensor_type) else item for item in shape
+            ]
+            out = _C_ops.final_state_reshape_(x, shape)
+        elif isinstance(shape, tmp_tensor_type):
+            shape.stop_gradient = True
+            out = _C_ops.final_state_reshape_(x, shape)
+        else:
+            raise ValueError(
+                "shape must be an instance of `list`, `tuple` or `Variable`,"
+                " got '{}.'".format(type(shape)))
+        return out
+    else:
+        if isinstance(shape, (list, tuple)):
+            shape = [
+                item.numpy().item(0) if isinstance(item, Variable) else item
+                for item in shape
+            ]
+            out, _ = _C_ops.reshape2_(x, None, 'shape', shape)
+            return out
+        elif isinstance(shape, Variable):
+            shape.stop_gradient = True
+            # NOTE(pangyoki): Cannot support the case where the shape Tensor
+            # is negative. In the infer_shape stage, the input's dim will
+            # be changed to a negative number.
+            # Thus, convert Shape Tensor to list firstly and then call
+            # reshape inplace op.
+            shape_list = shape.numpy().tolist()
+            out, _ = _C_ops.reshape2_(x, None, 'shape', shape_list)
+            return out
 
 
 def gather_nd(x, index, name=None):
@@ -4340,6 +4380,9 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
     if broadcast_shape:
         indices = paddle.broadcast_to(indices, broadcast_shape)
     values = paddle.broadcast_to(values, indices.shape)
+    if in_dygraph_mode():
+        return _C_ops.final_state_put_along_axis_(arr, indices, values, axis,
+                                                  reduce)
     return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis,
                                   "Reduce", reduce)
 
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index b161eeedc90..6199cd8120f 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -565,9 +565,12 @@ def add_(x, y, name=None):
     if out_shape != x.shape:
         raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
 
-    out = _elementwise_op_in_dygraph(
-        x, y, axis=axis, op_name=op_type)
-    return out
+    if in_dygraph_mode():
+        return _C_ops.final_state_add_(x, y)
+    else:
+        out = _elementwise_op_in_dygraph(
+            x, y, axis=axis, op_name=op_type)
+        return out
 
 
 def subtract(x, y, name=None):
@@ -650,9 +653,12 @@ def subtract_(x, y, name=None):
     if out_shape != x.shape:
         raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
 
-    out = _elementwise_op_in_dygraph(
-        x, y, axis=axis, act=act, op_name='elementwise_sub_')
-    return out
+    if in_dygraph_mode():
+        return _C_ops.final_state_subtract_(x, y)
+    else:
+        out = _elementwise_op_in_dygraph(
+            x, y, axis=axis, act=act, op_name='elementwise_sub_')
+        return out
 
 
 def divide(x, y, name=None):
@@ -3499,6 +3505,8 @@ def tanh_(x, name=None):
     Inplace version of ``tanh`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_tensor_tanh`.
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_tanh_(x)
     return _C_ops.tanh_(x)
 
 
@@ -4076,6 +4084,8 @@ def lerp_(x, y, weight, name=None):
     out_shape = broadcast_shape(out_shape, weight.shape)
     if out_shape != x.shape:
         raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
+    if in_dygraph_mode():
+        return _C_ops.final_state_lerp_(x, y, weight)
     return _C_ops.lerp_(x, y, weight)
 
 def erfinv(x, name=None):
@@ -4124,6 +4134,8 @@ def erfinv_(x, name=None):
         Please refer to :ref:`api_tensor_erfinv`.
     """
     check_type(x, 'x', (paddle.Tensor, Variable), 'erfinv')
+    if in_dygraph_mode():
+        return _C_ops.final_state_erfinv_(x)
    return _C_ops.erfinv_(x)
 
 def rad2deg(x, name=None):
-- 
GitLab
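
A quick local sanity check of the patched dispatch, as a minimal sketch
(assuming a build that includes this commit, running in eager dygraph mode,
the default in recent Paddle; the tensor values are arbitrary):

    import paddle

    # In eager mode the inplace APIs now route to the final_state_* ops
    # while keeping the inplace contract: the returned Tensor is the same
    # object as the input, which is what TestDygraphInplace asserts via
    # id(var) == id(inplace_var).
    x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
    y = paddle.tanh_(x)
    assert y is x

    # reshape_ accepts the shape as a Tensor, as exercised by the new
    # TestDygraphInplaceReshapeTensor case above. The legacy branch instead
    # converts the shape Tensor to a Python list first (see the NOTE in
    # reshape_), because a negative dim inside a shape Tensor cannot
    # survive the legacy infer_shape stage.
    shape = paddle.to_tensor(-1)
    paddle.reshape_(x, shape)
    print(x.shape)  # [4]

Every function touched here follows the same pattern: in_dygraph_mode()
selects the final-state op, while the _in_legacy_dygraph() / plain _C_ops
path keeps the old behavior, so legacy-dygraph users should be unaffected.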