未验证 提交 7c1e7e46 编写于 作者: P pangyoki 提交者: GitHub

call final_state method in inplace APIs (#42968)

* add forward inplace final state api

* fix bug

* fix reshape

* fix coverage

* add inplace info for erfinv, lerp, put_along_axis

* fix put_along_axis infer_meta

* fix format

* update yaml

* fix
上级 3aed9690
......@@ -92,11 +92,12 @@
- api : add
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : add
inplace : (x -> out)
backward : add_grad
- api : add_n
......@@ -470,6 +471,7 @@
func : UnchangedInferMeta
kernel :
func : ceil
inplace : (x -> out)
backward : ceil_grad
- api : celu
......@@ -779,12 +781,13 @@
# elu
- api : elu
args : (Tensor x, float alpha)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : elu
inplace : (x -> out)
backward : elu_grad
- api : embedding
......@@ -836,11 +839,12 @@
# exp
- api : exp
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : exp
inplace : (x -> out)
backward : exp_grad
# expand
......@@ -958,6 +962,7 @@
func : UnchangedInferMeta
kernel :
func : floor
inplace : (x -> out)
backward : floor_grad
- api : floor_divide
......@@ -1432,11 +1437,12 @@
- api : lerp
args : (Tensor x, Tensor y, Tensor weight)
output : Tensor
output : Tensor(out)
infer_meta :
func : LerpInferMeta
kernel :
func : lerp
inplace : (x -> out)
backward : lerp_grad
- api : less_equal
......@@ -2052,13 +2058,14 @@
# put_along_axis
- api : put_along_axis
args : (Tensor x, Tensor index, Tensor value, int axis, str reduce)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [index]
param : [x]
kernel :
func : put_along_axis
data_type : x
inplace : (x -> out)
backward : put_along_axis_grad
- api : qr
......@@ -2105,11 +2112,12 @@
- api : reciprocal
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : reciprocal
inplace : (x -> out)
backward : reciprocal_grad
# reduce_prod
......@@ -2253,6 +2261,7 @@
func : UnchangedInferMeta
kernel :
func : round
inplace : (x -> out)
backward : round_grad
- api : rsqrt
......@@ -2279,12 +2288,13 @@
- api : scatter
args : (Tensor x, Tensor index, Tensor updates, bool overwrite)
output : Tensor
output : Tensor(out)
infer_meta :
func : ScatterInferMeta
dtype : x
kernel :
func : scatter
inplace : (x -> out)
backward : scatter_grad
- api : scatter_nd_add
......@@ -2463,12 +2473,13 @@
- api : softmax
args : (Tensor x, int axis)
output : Tensor
output : Tensor(out)
infer_meta :
func : SoftmaxInferMeta
kernel :
func : softmax
use_gpudnn : true
inplace : (x -> out)
backward : softmax_grad
- api : softplus
......@@ -2510,11 +2521,12 @@
- api : sqrt
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : sqrt
inplace : (x -> out)
backward : sqrt_grad
- api : square
......@@ -2542,6 +2554,7 @@
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze_with_xshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
......@@ -2566,11 +2579,12 @@
- api : subtract
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : subtract
inplace : (x -> out)
backward : subtract_grad
- api : sum
......@@ -2640,11 +2654,12 @@
# tanh
- api : tanh
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : tanh
inplace : (x -> out)
backward : tanh_grad
# tanh_shrink
......@@ -2817,6 +2832,7 @@
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze_with_xshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
......
......@@ -302,6 +302,17 @@ class TestDygraphInplaceReshape(TestDygraphInplace):
return paddle.reshape_(var, [-1])
class TestDygraphInplaceReshapeTensor(TestDygraphInplace):
def non_inplace_api_processing(self, var):
shape = paddle.to_tensor(-1)
return paddle.reshape(var, shape)
def inplace_api_processing(self, var):
shape = paddle.to_tensor(-1)
return paddle.reshape_(var, shape)
class TestDygraphInplaceFlatten(TestDygraphInplace):
def non_inplace_api_processing(self, var):
......
......@@ -136,6 +136,8 @@ def elu_(x, alpha=1.0, name=None):
Please refer to :ref:`api_nn_cn_elu`.
"""
assert alpha >= 0., "elu_ only support alpha >= 0, please use elu instead."
if in_dygraph_mode():
return _C_ops.final_state_elu_(x, alpha)
return _C_ops.elu_(x, 'alpha', alpha)
......@@ -1146,7 +1148,16 @@ def softmax_(x, axis=-1, dtype=None, name=None):
if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
dtype = convert_np_dtype_to_dtype_(dtype)
use_cudnn = True
return _C_ops.softmax_(x, 'axis', axis, 'use_cudnn', use_cudnn)
if in_dygraph_mode():
outs_cast = x if dtype is None \
else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return _C_ops.final_state_softmax_(outs_cast, axis)
if _in_legacy_dygraph():
outs_cast = x if dtype is None \
else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return _C_ops.softmax_(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn)
def softplus(x, beta=1, threshold=20, name=None):
......
......@@ -305,6 +305,10 @@ def generate_inplace_fn(inplace_op_type):
origin_op_type = inplace_op_type[:-1]
def func(x, name=None):
final_state_inplace_op_type = "final_state_%s" % inplace_op_type
if in_dygraph_mode() and hasattr(_C_ops, final_state_inplace_op_type):
op = getattr(_C_ops, final_state_inplace_op_type)
return op(x)
if _non_static_mode():
op = getattr(_C_ops, inplace_op_type)
return op(x)
......
......@@ -1524,9 +1524,14 @@ def flatten_(x, start_axis=0, stop_axis=-1, name=None):
if start_axis > stop_axis:
raise ValueError("The stop_axis should be larger than stat_axis")
dy_out, _ = _C_ops.flatten_contiguous_range_(x, 'start_axis', start_axis,
'stop_axis', stop_axis)
return dy_out
if in_dygraph_mode():
return _C_ops.final_state_flatten_(x, start_axis, stop_axis)
if _in_legacy_dygraph():
dy_out, _ = _C_ops.flatten_contiguous_range_(x, 'start_axis',
start_axis, 'stop_axis',
stop_axis)
return dy_out
def roll(x, shifts, axis=None, name=None):
......@@ -2052,8 +2057,13 @@ def squeeze_(x, axis=None, name=None):
elif isinstance(axis, tuple):
axis = list(axis)
out, _ = _C_ops.squeeze2_(x, 'axes', axis)
return out
input = x
axes = axis
if in_dygraph_mode():
return _C_ops.final_state_squeeze_(input, axes)
if _in_legacy_dygraph():
out, _ = _C_ops.squeeze2_(input, 'axes', axes)
return out
def unique_consecutive(x,
......@@ -2412,16 +2422,20 @@ def unsqueeze_(x, axis, name=None):
Inplace version of ``unsqueeze`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_tensor_unsqueeze`.
"""
if isinstance(axis, int):
axis = [axis]
elif isinstance(axis, Variable):
axis = axis.numpy().tolist()
elif isinstance(axis, (list, tuple)):
axis = [
input = x
axes = axis
if isinstance(axes, int):
axes = [axes]
elif isinstance(axes, Variable):
axes = axes.numpy().tolist()
elif isinstance(axes, (list, tuple)):
axes = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in axis
for item in axes
]
out, _ = _C_ops.unsqueeze2_(x, 'axes', axis)
if in_dygraph_mode():
return _C_ops.final_state_unsqueeze_(input, axes)
out, _ = _C_ops.unsqueeze2_(input, 'axes', axes)
return out
......@@ -2679,6 +2693,8 @@ def scatter_(x, index, updates, overwrite=True, name=None):
Inplace version of ``scatter`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_tensor_scatter`.
"""
if in_dygraph_mode():
return _C_ops.final_state_scatter_(x, index, updates, overwrite)
return _C_ops.scatter_(x, index, updates, 'overwrite', overwrite)
......@@ -3272,13 +3288,13 @@ def reshape(x, shape, name=None):
)
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
item.numpy().item(0)
if isinstance(item, tmp_tensor_type) else item for item in shape
]
out = _C_ops.final_state_reshape(x, shape)
elif isinstance(shape, tmp_tensor_type):
shape.stop_gradient = True
out, _ = _C_ops.reshape2(x, shape)
out = _C_ops.final_state_reshape(x, shape)
else:
raise ValueError(
"shape must be an instance of `list`, `tuple` or `Variable`,"
......@@ -3386,17 +3402,41 @@ def reshape_(x, shape, name=None):
Inplace version of ``reshape`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_tensor_reshape`.
"""
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = _C_ops.reshape2_(x, None, 'shape', shape)
return out
elif isinstance(shape, Variable):
shape.stop_gradient = True
out, _ = _C_ops.reshape2_(x, shape)
if in_dygraph_mode():
tmp_tensor_type = core.eager.Tensor
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0)
if isinstance(item, tmp_tensor_type) else item for item in shape
]
out = _C_ops.final_state_reshape_(x, shape)
elif isinstance(shape, tmp_tensor_type):
shape.stop_gradient = True
out = _C_ops.final_state_reshape_(x, shape)
else:
raise ValueError(
"shape must be an instance of `list`, `tuple` or `Variable`,"
" got '{}.'".format(type(shape)))
return out
else:
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = _C_ops.reshape2_(x, None, 'shape', shape)
return out
elif isinstance(shape, Variable):
shape.stop_gradient = True
# NOTE(pangyoki): Cannot support the case where the shape Tensor
# is negative. In the infer_shape stage, the input's dim will
# be changed to a negative number.
# Thus, convert Shape Tensor to list firstly and then call
# reshape inplace op.
shape_list = shape.numpy().tolist()
out, _ = _C_ops.reshape2_(x, None, 'shape', shape_list)
return out
def gather_nd(x, index, name=None):
......@@ -4340,6 +4380,9 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape)
if in_dygraph_mode():
return _C_ops.final_state_put_along_axis_(arr, indices, values, axis,
reduce)
return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce",
reduce)
......
......@@ -565,9 +565,12 @@ def add_(x, y, name=None):
if out_shape != x.shape:
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
out = _elementwise_op_in_dygraph(
x, y, axis=axis, op_name=op_type)
return out
if in_dygraph_mode():
return _C_ops.final_state_add_(x, y)
else:
out = _elementwise_op_in_dygraph(
x, y, axis=axis, op_name=op_type)
return out
def subtract(x, y, name=None):
......@@ -650,9 +653,12 @@ def subtract_(x, y, name=None):
if out_shape != x.shape:
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
out = _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name='elementwise_sub_')
return out
if in_dygraph_mode():
return _C_ops.final_state_subtract_(x, y)
else:
out = _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name='elementwise_sub_')
return out
def divide(x, y, name=None):
......@@ -3499,6 +3505,8 @@ def tanh_(x, name=None):
Inplace version of ``tanh`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_tanh`.
"""
if in_dygraph_mode():
return _C_ops.final_state_tanh_( x )
return _C_ops.tanh_(x)
......@@ -4076,6 +4084,8 @@ def lerp_(x, y, weight, name=None):
out_shape = broadcast_shape(out_shape, weight.shape)
if out_shape != x.shape:
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
if in_dygraph_mode():
return _C_ops.final_state_lerp_( x, y, weight)
return _C_ops.lerp_(x, y, weight)
def erfinv(x, name=None):
......@@ -4124,6 +4134,8 @@ def erfinv_(x, name=None):
Please refer to :ref:`api_tensor_erfinv`.
"""
check_type(x, 'x', (paddle.Tensor, Variable), 'erfinv')
if in_dygraph_mode():
return _C_ops.final_state_erfinv_( x )
return _C_ops.erfinv_(x)
def rad2deg(x, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册