Unverified commit 7c1e7e46 authored by pangyoki, committed by GitHub

call final_state method in inplace APIs (#42968)

* add forward inplace final state api

* fix bug

* fix reshape

* fix coverage

* add inplace info for erfinv, lerp, put_along_axis

* fix put_along_axis infer_meta

* fix format

* update yaml

* fix
Parent 3aed9690
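The commit applies one pattern throughout: in the new eager ("final state") dygraph mode, each inplace Python API dispatches to its auto-generated `final_state_*_` op, and otherwise falls back to the legacy `_C_ops` call. A minimal sketch of that dispatch, mirroring the `tanh_` change in the diff below (the import paths are the usual Paddle internals and are an assumption of this sketch, not part of the commit):

import paddle
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode

def tanh_(x, name=None):
    if in_dygraph_mode():
        # eager mode: call the generated final-state inplace op
        return _C_ops.final_state_tanh_(x)
    # legacy dygraph fallback keeps the old attribute-style call
    return _C_ops.tanh_(x)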
@@ -92,11 +92,12 @@
 - api : add
   args : (Tensor x, Tensor y)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : ElementwiseInferMeta
   kernel :
     func : add
+  inplace : (x -> out)
   backward : add_grad

 - api : add_n
@@ -470,6 +471,7 @@
     func : UnchangedInferMeta
   kernel :
     func : ceil
+  inplace : (x -> out)
   backward : ceil_grad

 - api : celu
@@ -779,12 +781,13 @@
 # elu
 - api : elu
   args : (Tensor x, float alpha)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
   kernel :
     func : elu
+  inplace : (x -> out)
   backward : elu_grad

 - api : embedding
@@ -836,11 +839,12 @@
 # exp
 - api : exp
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : exp
+  inplace : (x -> out)
   backward : exp_grad

 # expand
@@ -958,6 +962,7 @@
     func : UnchangedInferMeta
   kernel :
     func : floor
+  inplace : (x -> out)
   backward : floor_grad

 - api : floor_divide
@@ -1432,11 +1437,12 @@
 - api : lerp
   args : (Tensor x, Tensor y, Tensor weight)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : LerpInferMeta
   kernel :
     func : lerp
+  inplace : (x -> out)
   backward : lerp_grad

 - api : less_equal
@@ -2052,13 +2058,14 @@
 # put_along_axis
 - api : put_along_axis
   args : (Tensor x, Tensor index, Tensor value, int axis, str reduce)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
-    param : [index]
+    param : [x]
   kernel :
     func : put_along_axis
     data_type : x
+  inplace : (x -> out)
   backward : put_along_axis_grad

 - api : qr
@@ -2105,11 +2112,12 @@
 - api : reciprocal
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : reciprocal
+  inplace : (x -> out)
   backward : reciprocal_grad

 # reduce_prod
@@ -2253,6 +2261,7 @@
     func : UnchangedInferMeta
   kernel :
     func : round
+  inplace : (x -> out)
   backward : round_grad

 - api : rsqrt
@@ -2279,12 +2288,13 @@
 - api : scatter
   args : (Tensor x, Tensor index, Tensor updates, bool overwrite)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : ScatterInferMeta
     dtype : x
   kernel :
     func : scatter
+  inplace : (x -> out)
   backward : scatter_grad

 - api : scatter_nd_add
@@ -2463,12 +2473,13 @@
 - api : softmax
   args : (Tensor x, int axis)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : SoftmaxInferMeta
   kernel :
     func : softmax
     use_gpudnn : true
+  inplace : (x -> out)
   backward : softmax_grad

 - api : softplus
@@ -2510,11 +2521,12 @@
 - api : sqrt
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : sqrt
+  inplace : (x -> out)
   backward : sqrt_grad

 - api : square
@@ -2542,6 +2554,7 @@
     func : SqueezeWithXShapeInferMeta
   kernel :
     func : squeeze_with_xshape
+  inplace : (x -> out)
   view: (x -> out)
   intermediate : xshape
   backward : squeeze_grad
@@ -2566,11 +2579,12 @@
 - api : subtract
   args : (Tensor x, Tensor y)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : ElementwiseInferMeta
   kernel :
     func : subtract
+  inplace : (x -> out)
   backward : subtract_grad

 - api : sum
@@ -2640,11 +2654,12 @@
 # tanh
 - api : tanh
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : tanh
+  inplace : (x -> out)
   backward : tanh_grad

 # tanh_shrink
@@ -2817,6 +2832,7 @@
     func : UnsqueezeWithXShapeInferMeta
   kernel :
     func : unsqueeze_with_xshape
+  inplace : (x -> out)
   view: (x -> out)
   intermediate : xshape
   backward : unsqueeze_grad
...
@@ -302,6 +302,17 @@ class TestDygraphInplaceReshape(TestDygraphInplace):
         return paddle.reshape_(var, [-1])


+class TestDygraphInplaceReshapeTensor(TestDygraphInplace):
+    def non_inplace_api_processing(self, var):
+        shape = paddle.to_tensor(-1)
+        return paddle.reshape(var, shape)
+
+    def inplace_api_processing(self, var):
+        shape = paddle.to_tensor(-1)
+        return paddle.reshape_(var, shape)
+
+
 class TestDygraphInplaceFlatten(TestDygraphInplace):
     def non_inplace_api_processing(self, var):
...
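The behavior the new test covers can be reproduced directly; a small illustrative example (the values here are chosen for illustration, not taken from the test file):

import paddle

# Mirrors TestDygraphInplaceReshapeTensor: reshape_ also accepts the
# target shape as a Tensor, here a 0-D tensor holding -1.
var = paddle.rand([2, 3])
shape = paddle.to_tensor(-1)
paddle.reshape_(var, shape)  # in place: var becomes 1-D with 6 elements
print(var.shape)             # [6]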
@@ -136,6 +136,8 @@ def elu_(x, alpha=1.0, name=None):
     Please refer to :ref:`api_nn_cn_elu`.
     """
     assert alpha >= 0., "elu_ only support alpha >= 0, please use elu instead."
+    if in_dygraph_mode():
+        return _C_ops.final_state_elu_(x, alpha)
     return _C_ops.elu_(x, 'alpha', alpha)
@@ -1146,7 +1148,16 @@ def softmax_(x, axis=-1, dtype=None, name=None):
     if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
         dtype = convert_np_dtype_to_dtype_(dtype)
     use_cudnn = True
-    return _C_ops.softmax_(x, 'axis', axis, 'use_cudnn', use_cudnn)
+
+    if in_dygraph_mode():
+        outs_cast = x if dtype is None \
+            else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
+        return _C_ops.final_state_softmax_(outs_cast, axis)
+
+    if _in_legacy_dygraph():
+        outs_cast = x if dtype is None \
+            else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
+        return _C_ops.softmax_(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn)


 def softplus(x, beta=1, threshold=20, name=None):
...
@@ -305,6 +305,10 @@ def generate_inplace_fn(inplace_op_type):
     origin_op_type = inplace_op_type[:-1]

     def func(x, name=None):
+        final_state_inplace_op_type = "final_state_%s" % inplace_op_type
+        if in_dygraph_mode() and hasattr(_C_ops, final_state_inplace_op_type):
+            op = getattr(_C_ops, final_state_inplace_op_type)
+            return op(x)
         if _non_static_mode():
             op = getattr(_C_ops, inplace_op_type)
             return op(x)
...
@@ -1524,9 +1524,14 @@ def flatten_(x, start_axis=0, stop_axis=-1, name=None):
     if start_axis > stop_axis:
         raise ValueError("The stop_axis should be larger than stat_axis")

-    dy_out, _ = _C_ops.flatten_contiguous_range_(x, 'start_axis', start_axis,
-                                                 'stop_axis', stop_axis)
-    return dy_out
+    if in_dygraph_mode():
+        return _C_ops.final_state_flatten_(x, start_axis, stop_axis)
+
+    if _in_legacy_dygraph():
+        dy_out, _ = _C_ops.flatten_contiguous_range_(x, 'start_axis',
+                                                     start_axis, 'stop_axis',
+                                                     stop_axis)
+        return dy_out


 def roll(x, shifts, axis=None, name=None):
@@ -2052,8 +2057,13 @@ def squeeze_(x, axis=None, name=None):
     elif isinstance(axis, tuple):
         axis = list(axis)

-    out, _ = _C_ops.squeeze2_(x, 'axes', axis)
-    return out
+    input = x
+    axes = axis
+    if in_dygraph_mode():
+        return _C_ops.final_state_squeeze_(input, axes)
+    if _in_legacy_dygraph():
+        out, _ = _C_ops.squeeze2_(input, 'axes', axes)
+        return out


 def unique_consecutive(x,
@@ -2412,16 +2422,20 @@ def unsqueeze_(x, axis, name=None):
     Inplace version of ``unsqueeze`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_paddle_tensor_unsqueeze`.
     """
-    if isinstance(axis, int):
-        axis = [axis]
-    elif isinstance(axis, Variable):
-        axis = axis.numpy().tolist()
-    elif isinstance(axis, (list, tuple)):
-        axis = [
+    input = x
+    axes = axis
+    if isinstance(axes, int):
+        axes = [axes]
+    elif isinstance(axes, Variable):
+        axes = axes.numpy().tolist()
+    elif isinstance(axes, (list, tuple)):
+        axes = [
             item.numpy().item(0) if isinstance(item, Variable) else item
-            for item in axis
+            for item in axes
         ]
-    out, _ = _C_ops.unsqueeze2_(x, 'axes', axis)
+    if in_dygraph_mode():
+        return _C_ops.final_state_unsqueeze_(input, axes)
+    out, _ = _C_ops.unsqueeze2_(input, 'axes', axes)
     return out
@@ -2679,6 +2693,8 @@ def scatter_(x, index, updates, overwrite=True, name=None):
     Inplace version of ``scatter`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_paddle_tensor_scatter`.
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_scatter_(x, index, updates, overwrite)
     return _C_ops.scatter_(x, index, updates, 'overwrite', overwrite)
@@ -3272,13 +3288,13 @@ def reshape(x, shape, name=None):
             )
         if isinstance(shape, (list, tuple)):
             shape = [
-                item.numpy().item(0) if isinstance(item, Variable) else item
-                for item in shape
+                item.numpy().item(0)
+                if isinstance(item, tmp_tensor_type) else item for item in shape
             ]
             out = _C_ops.final_state_reshape(x, shape)
         elif isinstance(shape, tmp_tensor_type):
             shape.stop_gradient = True
-            out, _ = _C_ops.reshape2(x, shape)
+            out = _C_ops.final_state_reshape(x, shape)
         else:
             raise ValueError(
                 "shape must be an instance of `list`, `tuple` or `Variable`,"
@@ -3386,17 +3402,41 @@ def reshape_(x, shape, name=None):
     Inplace version of ``reshape`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_paddle_tensor_reshape`.
     """
-    if isinstance(shape, (list, tuple)):
-        shape = [
-            item.numpy().item(0) if isinstance(item, Variable) else item
-            for item in shape
-        ]
-        out, _ = _C_ops.reshape2_(x, None, 'shape', shape)
-        return out
-    elif isinstance(shape, Variable):
-        shape.stop_gradient = True
-        out, _ = _C_ops.reshape2_(x, shape)
-        return out
+    if in_dygraph_mode():
+        tmp_tensor_type = core.eager.Tensor
+        if isinstance(shape, (list, tuple)):
+            shape = [
+                item.numpy().item(0)
+                if isinstance(item, tmp_tensor_type) else item for item in shape
+            ]
+            out = _C_ops.final_state_reshape_(x, shape)
+        elif isinstance(shape, tmp_tensor_type):
+            shape.stop_gradient = True
+            out = _C_ops.final_state_reshape_(x, shape)
+        else:
+            raise ValueError(
+                "shape must be an instance of `list`, `tuple` or `Variable`,"
+                " got '{}.'".format(type(shape)))
+        return out
+    else:
+        if isinstance(shape, (list, tuple)):
+            shape = [
+                item.numpy().item(0) if isinstance(item, Variable) else item
+                for item in shape
+            ]
+            out, _ = _C_ops.reshape2_(x, None, 'shape', shape)
+            return out
+        elif isinstance(shape, Variable):
+            shape.stop_gradient = True
+            # NOTE(pangyoki): Cannot support the case where the shape Tensor
+            # is negative. In the infer_shape stage, the input's dim will
+            # be changed to a negative number.
+            # Thus, convert Shape Tensor to list firstly and then call
+            # reshape inplace op.
+            shape_list = shape.numpy().tolist()
+            out, _ = _C_ops.reshape2_(x, None, 'shape', shape_list)
+            return out


 def gather_nd(x, index, name=None):
@@ -4340,6 +4380,9 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
     if broadcast_shape:
         indices = paddle.broadcast_to(indices, broadcast_shape)
     values = paddle.broadcast_to(values, indices.shape)
+    if in_dygraph_mode():
+        return _C_ops.final_state_put_along_axis_(arr, indices, values, axis,
+                                                  reduce)
     return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce",
                                   reduce)
...
@@ -565,9 +565,12 @@ def add_(x, y, name=None):
     if out_shape != x.shape:
         raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))

-    out = _elementwise_op_in_dygraph(
-        x, y, axis=axis, op_name=op_type)
-    return out
+    if in_dygraph_mode():
+        return _C_ops.final_state_add_(x, y)
+    else:
+        out = _elementwise_op_in_dygraph(
+            x, y, axis=axis, op_name=op_type)
+        return out


 def subtract(x, y, name=None):
@@ -650,9 +653,12 @@ def subtract_(x, y, name=None):
     if out_shape != x.shape:
         raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))

-    out = _elementwise_op_in_dygraph(
-        x, y, axis=axis, act=act, op_name='elementwise_sub_')
-    return out
+    if in_dygraph_mode():
+        return _C_ops.final_state_subtract_(x, y)
+    else:
+        out = _elementwise_op_in_dygraph(
+            x, y, axis=axis, act=act, op_name='elementwise_sub_')
+        return out


 def divide(x, y, name=None):
@@ -3499,6 +3505,8 @@ def tanh_(x, name=None):
     Inplace version of ``tanh`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_tensor_tanh`.
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_tanh_(x)
     return _C_ops.tanh_(x)
@@ -4076,6 +4084,8 @@ def lerp_(x, y, weight, name=None):
     out_shape = broadcast_shape(out_shape, weight.shape)
     if out_shape != x.shape:
         raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
+    if in_dygraph_mode():
+        return _C_ops.final_state_lerp_(x, y, weight)
     return _C_ops.lerp_(x, y, weight)


 def erfinv(x, name=None):
@@ -4124,6 +4134,8 @@ def erfinv_(x, name=None):
     Please refer to :ref:`api_tensor_erfinv`.
     """
     check_type(x, 'x', (paddle.Tensor, Variable), 'erfinv')
+    if in_dygraph_mode():
+        return _C_ops.final_state_erfinv_(x)
     return _C_ops.erfinv_(x)


 def rad2deg(x, name=None):
...
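User-facing behavior of the inplace APIs is unchanged by this commit; only the op they route to differs by dygraph mode. A quick illustrative check (values chosen here, not from the commit):

import paddle

x = paddle.to_tensor([0.5, 1.0, 2.0])
paddle.sqrt_(x)  # modifies x in place; in eager mode this now goes
                 # through the generated final_state_sqrt_ op
print(x)         # [0.7071..., 1.0, 1.4142...]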