未验证 提交 58a03d41 编写于 作者: G GGBond8488 提交者: GitHub

【inplace api】batch add inplace api paddle.log_, paddle.i0_,...

【inplace api】batch add inplace api paddle.log_, paddle.i0_, paddle.nn.functional.leaky_relu_... (#55576)

* batch add inplace api

* add inplace test

* add activation inplace

* fix test

* remove atan2 ge, gt, le, lt, nq

* remove atan2 ge, gt, le, lt, nq

* fix windows ci error

* rerun ci

* fix typro

* fix bugs

---------
Co-authored-by: Nzhangrui34 <v_zhangrui34@baidu.com>
上级 da258964
......@@ -907,6 +907,7 @@
func : TrilInferMeta
kernel :
func : tril
inplace: (x -> out)
backward : tril_grad
- op : tril_indices
......@@ -928,6 +929,7 @@
func : TriuInferMeta
kernel :
func : triu
inplace: (x -> out)
backward : triu_grad
- op : triu_indices
......
......@@ -665,11 +665,12 @@
- op : digamma
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : digamma
inplace: (x -> out)
backward : digamma_grad
- op : dirichlet
......@@ -1107,12 +1108,13 @@
- op : hardtanh
args : (Tensor x, float t_min=0, float t_max=24)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : hardtanh
inplace: (x -> out)
backward : hardtanh_grad
- op : heaviside
......@@ -1149,6 +1151,7 @@
func : UnchangedInferMeta
kernel :
func : i0
inplace: (x -> out)
backward : i0_grad
- op : i0e
......@@ -1361,12 +1364,13 @@
- op : leaky_relu
args : (Tensor x, float negative_slope = 0.02f)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : leaky_relu
inplace: (x -> out)
backward : leaky_relu_grad
- op : lerp
......@@ -1386,6 +1390,7 @@
func : UnchangedInferMeta
kernel :
func : lgamma
inplace: (x -> out)
backward : lgamma_grad
- op : linear_interp
......@@ -1413,38 +1418,42 @@
- op : log
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log
inplace: (x -> out)
backward: log_grad
- op : log10
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log10
inplace: (x -> out)
backward: log10_grad
- op : log1p
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log1p
inplace: (x -> out)
backward: log1p_grad
- op : log2
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log2
inplace: (x -> out)
backward: log2_grad
- op : log_loss
......@@ -1517,12 +1526,13 @@
- op : logit
args : (Tensor x, float eps = 1e-6f)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : logit
inplace: (x -> out)
backward : logit_grad
- op : logsigmoid
......@@ -1895,6 +1905,7 @@
param: [x]
kernel :
func : polygamma
inplace: (x -> out)
backward : polygamma_grad
- op : pow
......@@ -2494,12 +2505,13 @@
- op : thresholded_relu
args : (Tensor x, float threshold = 1.0)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : thresholded_relu
inplace: (x -> out)
backward : thresholded_relu_grad
- op : topk
......@@ -2546,11 +2558,12 @@
- op : trunc
args : (Tensor input)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : trunc
inplace: (input -> out)
backward : trunc_grad
- op : unbind
......
......@@ -2032,7 +2032,7 @@ struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log<U>()).eval();
}
};
......@@ -2076,7 +2076,7 @@ struct Log2Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log2<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log2<U>()).eval();
}
};
......@@ -2121,7 +2121,7 @@ struct Log10Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log10<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log10<U>()).eval();
}
};
......@@ -2166,7 +2166,7 @@ struct Log1pFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log1p<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log1p<U>()).eval();
}
};
......
......@@ -110,7 +110,9 @@ from .tensor.creation import arange # noqa: F401
from .tensor.creation import full # noqa: F401
from .tensor.creation import full_like # noqa: F401
from .tensor.creation import triu # noqa: F401
from .tensor.creation import triu_ # noqa: F401
from .tensor.creation import tril # noqa: F401
from .tensor.creation import tril_ # noqa: F401
from .tensor.creation import meshgrid # noqa: F401
from .tensor.creation import empty # noqa: F401
from .tensor.creation import empty_like # noqa: F401
......@@ -224,14 +226,18 @@ from .tensor.math import cummin # noqa: F401
from .tensor.math import cumprod # noqa: F401
from .tensor.math import logcumsumexp # noqa: F401
from .tensor.math import logit # noqa: F401
from .tensor.math import logit_ # noqa: F401
from .tensor.math import exp # noqa: F401
from .tensor.math import expm1 # noqa: F401
from .tensor.math import expm1_ # noqa: F401
from .tensor.math import floor # noqa: F401
from .tensor.math import increment # noqa: F401
from .tensor.math import log # noqa: F401
from .tensor.math import log_ # noqa: F401
from .tensor.math import log2_ # noqa: F401
from .tensor.math import log2 # noqa: F401
from .tensor.math import log10 # noqa: F401
from .tensor.math import log10_ # noqa: F401
from .tensor.math import multiplex # noqa: F401
from .tensor.math import pow # noqa: F401
from .tensor.math import pow_ # noqa: F401
......@@ -279,6 +285,7 @@ from .tensor.math import logsumexp # noqa: F401
from .tensor.math import logaddexp # noqa: F401
from .tensor.math import inverse # noqa: F401
from .tensor.math import log1p # noqa: F401
from .tensor.math import log1p_ # noqa: F401
from .tensor.math import erf # noqa: F401
from .tensor.math import erf_ # noqa: F401
from .tensor.math import addmm # noqa: F401
......@@ -294,9 +301,13 @@ from .tensor.math import prod # noqa: F401
from .tensor.math import broadcast_shape # noqa: F401
from .tensor.math import conj # noqa: F401
from .tensor.math import trunc # noqa: F401
from .tensor.math import trunc_ # noqa: F401
from .tensor.math import digamma # noqa: F401
from .tensor.math import digamma_ # noqa: F401
from .tensor.math import neg # noqa: F401
from .tensor.math import neg_ # noqa: F401
from .tensor.math import lgamma # noqa: F401
from .tensor.math import lgamma_ # noqa: F401
from .tensor.math import acosh # noqa: F401
from .tensor.math import acosh_ # noqa: F401
from .tensor.math import asinh # noqa: F401
......@@ -317,6 +328,7 @@ from .tensor.math import inner # noqa: F401
from .tensor.math import outer # noqa: F401
from .tensor.math import heaviside # noqa: F401
from .tensor.math import frac # noqa: F401
from .tensor.math import frac_ # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.math import frexp # noqa: F401
......@@ -326,10 +338,12 @@ from .tensor.math import cumulative_trapezoid # noqa: F401
from .tensor.math import vander # noqa: F401
from .tensor.math import nextafter # noqa: F401
from .tensor.math import i0 # noqa: F401
from .tensor.math import i0_ # noqa: F401
from .tensor.math import i0e # noqa: F401
from .tensor.math import i1 # noqa: F401
from .tensor.math import i1e # noqa: F401
from .tensor.math import polygamma # noqa: F401
from .tensor.math import polygamma_ # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
......@@ -473,6 +487,7 @@ __all__ = [ # noqa
'logaddexp',
'logcumsumexp',
'logit',
'logit_',
'LazyGuard',
'sign',
'is_empty',
......@@ -561,6 +576,7 @@ __all__ = [ # noqa
'rand',
'less_equal',
'triu',
'triu_',
'sin',
'sin_',
'dist',
......@@ -582,6 +598,7 @@ __all__ = [ # noqa
'abs',
'abs_',
'tril',
'tril_',
'pow',
'pow_',
'zeros_like',
......@@ -608,7 +625,9 @@ __all__ = [ # noqa
'broadcast_shape',
'conj',
'neg',
'neg_',
'lgamma',
'lgamma_',
'lerp',
'erfinv',
'inner',
......@@ -693,13 +712,19 @@ __all__ = [ # noqa
'floor',
'cosh',
'log',
'log_',
'log2',
'log2_',
'log10',
'log10_',
'concat',
'check_shape',
'trunc',
'trunc_',
'frac',
'frac_',
'digamma',
'digamma_',
'standard_normal',
'diagonal',
'broadcast_tensors',
......@@ -741,8 +766,10 @@ __all__ = [ # noqa
'unflatten',
'nextafter',
'i0',
'i0_',
'i0e',
'i1',
'i1e',
'polygamma',
'polygamma_',
]
......@@ -21,9 +21,11 @@ from .activation import elu_ # noqa: F401
from .activation import gelu # noqa: F401
from .activation import hardshrink # noqa: F401
from .activation import hardtanh # noqa: F401
from .activation import hardtanh_ # noqa: F401
from .activation import hardsigmoid # noqa: F401
from .activation import hardswish # noqa: F401
from .activation import leaky_relu # noqa: F401
from .activation import leaky_relu_ # noqa: F401
from .activation import log_sigmoid # noqa: F401
from .activation import maxout # noqa: F401
from .activation import prelu # noqa: F401
......@@ -44,6 +46,7 @@ from .activation import tanh # noqa: F401
from .activation import tanh_ # noqa: F401
from .activation import tanhshrink # noqa: F401
from .activation import thresholded_relu # noqa: F401
from .activation import thresholded_relu_ # noqa: F401
from .activation import log_softmax # noqa: F401
from .activation import glu # noqa: F401
from .activation import gumbel_softmax # noqa: F401
......@@ -153,9 +156,11 @@ __all__ = [ # noqa
'gelu',
'hardshrink',
'hardtanh',
'hardtanh_',
'hardsigmoid',
'hardswish',
'leaky_relu',
'leaky_relu_',
'log_sigmoid',
'maxout',
'prelu',
......@@ -176,6 +181,7 @@ __all__ = [ # noqa
'tanh_',
'tanhshrink',
'thresholded_relu',
'thresholded_relu_',
'log_softmax',
'glu',
'gumbel_softmax',
......
......@@ -299,6 +299,16 @@ def hardtanh(x, min=-1.0, max=1.0, name=None):
return out
@inplace_apis_in_dygraph_only
def hardtanh_(x, min=-1.0, max=1.0, name=None):
r"""
Inplace version of ``hardtanh`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`paddle_nn_functional_hardtanh`.
"""
if in_dynamic_mode():
return _C_ops.hardtanh_(x, min, max)
def hardsigmoid(x, slope=0.1666667, offset=0.5, name=None):
r"""
hardsigmoid activation. Calculate the `hardsigmoid` of input `x`.
......@@ -458,6 +468,16 @@ def leaky_relu(x, negative_slope=0.01, name=None):
return out
@inplace_apis_in_dygraph_only
def leaky_relu_(x, negative_slope=0.01, name=None):
r"""
Inplace version of ``leaky_relu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`paddle_nn_functional_leaky_relu`.
"""
if in_dynamic_mode():
return _C_ops.leaky_relu_(x, negative_slope)
def prelu(x, weight, data_format="NCHW", name=None):
"""
prelu activation. The calculation formula is follows:
......@@ -1498,6 +1518,16 @@ def thresholded_relu(x, threshold=1.0, name=None):
return out
@inplace_apis_in_dygraph_only
def thresholded_relu_(x, threshold=1.0, name=None):
r"""
Inplace version of ``thresholded_relu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`paddle_nn_functional_thresholded_relu`.
"""
if in_dynamic_mode():
return _C_ops.thresholded_relu_(x, threshold)
def log_softmax(x, axis=-1, dtype=None, name=None):
r"""
This operator implements the log_softmax layer. The calculation process is
......
......@@ -35,7 +35,9 @@ from .creation import arange # noqa: F401
from .creation import full # noqa: F401
from .creation import full_like # noqa: F401
from .creation import triu # noqa: F401
from .creation import triu_ # noqa: F401
from .creation import tril # noqa: F401
from .creation import tril_ # noqa: F401
from .creation import meshgrid # noqa: F401
from .creation import empty # noqa: F401
from .creation import empty_like # noqa: F401
......@@ -162,6 +164,7 @@ from .math import cummin # noqa: F401
from .math import cumprod # noqa: F401
from .math import logcumsumexp # noqa: F401
from .math import logit # noqa: F401
from .math import logit_ # noqa: F401
from .math import exp # noqa: F401
from .math import exp_ # noqa: F401
from .math import expm1 # noqa: F401
......@@ -169,6 +172,7 @@ from .math import floor # noqa: F401
from .math import floor_ # noqa: F401
from .math import increment # noqa: F401
from .math import log # noqa: F401
from .math import log_ # noqa: F401
from .math import multiplex # noqa: F401
from .math import pow # noqa: F401
from .math import pow_ # noqa: F401
......@@ -221,8 +225,11 @@ from .math import logsumexp # noqa: F401
from .math import logaddexp # noqa: F401
from .math import inverse # noqa: F401
from .math import log2 # noqa: F401
from .math import log2_ # noqa: F401
from .math import log10 # noqa: F401
from .math import log10_ # noqa: F401
from .math import log1p # noqa: F401
from .math import log1p_ # noqa: F401
from .math import erf # noqa: F401
from .math import addmm # noqa: F401
from .math import addmm_ # noqa: F401
......@@ -239,9 +246,13 @@ from .math import any # noqa: F401
from .math import broadcast_shape # noqa: F401
from .math import conj # noqa: F401
from .math import trunc # noqa: F401
from .math import trunc_ # noqa: F401
from .math import digamma # noqa: F401
from .math import digamma_ # noqa: F401
from .math import neg # noqa: F401
from .math import neg_ # noqa: F401
from .math import lgamma # noqa: F401
from .math import lgamma_ # noqa: F401
from .math import diagonal # noqa: F401
from .math import acosh # noqa: F401
from .math import acosh_ # noqa: F401
......@@ -265,6 +276,7 @@ from .math import inner # noqa: F401
from .math import outer # noqa: F401
from .math import heaviside # noqa: F401
from .math import frac # noqa: F401
from .math import frac_ # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .math import frexp # noqa: F401
......@@ -276,10 +288,12 @@ from .math import sigmoid_ # noqa: F401
from .math import vander # noqa: F401
from .math import nextafter # noqa: F401
from .math import i0 # noqa: F401
from .math import i0_ # noqa: F401
from .math import i0e # noqa: F401
from .math import i1 # noqa: F401
from .math import i1e # noqa: F401
from .math import polygamma # noqa: F401
from .math import polygamma_ # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
......@@ -367,6 +381,7 @@ tensor_method_func = [ # noqa
'cumprod',
'logcumsumexp',
'logit',
'logit_',
'exp',
'exp_',
'expm1',
......@@ -375,8 +390,11 @@ tensor_method_func = [ # noqa
'increment',
'logaddexp',
'log',
'log_',
'log2',
'log2_',
'log10',
'log10_',
'logsumexp',
'multiplex',
'pow',
......@@ -432,6 +450,7 @@ tensor_method_func = [ # noqa
'logsumexp',
'inverse',
'log1p',
'log1p_',
'erf',
'addmm',
'addmm_',
......@@ -446,19 +465,26 @@ tensor_method_func = [ # noqa
'broadcast_shape',
'conj',
'neg',
'neg_',
'lgamma',
'lgamma_',
'equal',
'equal_all',
'greater_equal',
'greater_equal_',
'greater_than',
'greater_than_',
'is_empty',
'less_equal',
'less_equal_',
'less_than',
'less_than_',
'logical_and',
'logical_not',
'logical_or',
'logical_xor',
'not_equal',
'not_equal_',
'allclose',
'isclose',
'is_tensor',
......@@ -525,9 +551,12 @@ tensor_method_func = [ # noqa
'imag',
'is_floating_point',
'digamma',
'digamma_',
'diagonal',
'trunc',
'trunc_',
'frac',
'frac_',
'bitwise_and',
'bitwise_or',
'bitwise_xor',
......@@ -583,10 +612,12 @@ tensor_method_func = [ # noqa
'nextafter',
'unflatten',
'i0',
'i0_',
'i0e',
'i1',
'i1e',
'polygamma',
'polygamma_',
]
# this list used in math_op_patch.py for magic_method bind
......
......@@ -22,6 +22,7 @@ import numpy as np
import paddle
from paddle import _C_ops
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
from ..fluid.data_feeder import (
check_dtype,
......@@ -1462,6 +1463,17 @@ def tril(x, diagonal=0, name=None):
return _tril_triu_op(LayerHelper('tril', **locals()))
@inplace_apis_in_dygraph_only
def tril_(x, diagonal=0, name=None):
r"""
Inplace version of ``tril`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_tril`.
"""
if in_dynamic_mode():
return _C_ops.tril_(x, diagonal)
def triu(x, diagonal=0, name=None):
r"""
Return the upper triangular part of a matrix (2-D tensor) or batch of matrices
......@@ -1524,6 +1536,17 @@ def triu(x, diagonal=0, name=None):
return _tril_triu_op(LayerHelper('triu', **locals()))
@inplace_apis_in_dygraph_only
def triu_(x, diagonal=0, name=None):
r"""
Inplace version of ``triu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_triu`.
"""
if in_dynamic_mode():
return _C_ops.triu_(x, diagonal)
def meshgrid(*args, **kwargs):
"""
......
......@@ -187,6 +187,17 @@ def log(x, name=None):
return out
@inplace_apis_in_dygraph_only
def log_(x, name=None):
r"""
Inplace version of ``log`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_log`.
"""
if in_dynamic_mode():
return _C_ops.log_(x)
def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"""
Scale operator.
......@@ -1821,6 +1832,16 @@ def trunc(input, name=None):
return out
@inplace_apis_in_dygraph_only
def trunc_(input, name=None):
r"""
Inplace version of ``trunc`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_trunc`.
"""
if in_dynamic_mode():
return _C_ops.trunc_(input)
def mm(input, mat2, name=None):
"""
......@@ -2877,6 +2898,17 @@ def log1p(x, name=None):
return out
@inplace_apis_in_dygraph_only
def log1p_(x, name=None):
r"""
Inplace version of ``log1p`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_log1p`.
"""
if in_dynamic_mode():
return _C_ops.log1p_(x)
def log2(x, name=None):
r"""
Calculates the log to the base 2 of the given input tensor, element-wise.
......@@ -2932,6 +2964,17 @@ def log2(x, name=None):
return out
@inplace_apis_in_dygraph_only
def log2_(x, name=None):
r"""
Inplace version of ``log2`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_log2`.
"""
if in_dynamic_mode():
return _C_ops.log2_(x)
def log10(x, name=None):
r"""
Calculates the log to the base 10 of the given input tensor, element-wise.
......@@ -2987,6 +3030,17 @@ def log10(x, name=None):
return out
@inplace_apis_in_dygraph_only
def log10_(x, name=None):
r"""
Inplace version of ``log10`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_log10`.
"""
if in_dynamic_mode():
return _C_ops.log10_(x)
def clip(x, min=None, max=None, name=None):
"""
This operator clip all elements in input into the range [ min, max ] and return
......@@ -4385,6 +4439,16 @@ def digamma(x, name=None):
return out
@inplace_apis_in_dygraph_only
def digamma_(x, name=None):
r"""
Inplace version of ``digamma`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_digamma`.
"""
if in_dynamic_mode():
return _C_ops.digamma_(x)
def lgamma(x, name=None):
r"""
Calculates the lgamma of the given input tensor, element-wise.
......@@ -4422,6 +4486,16 @@ def lgamma(x, name=None):
return out
@inplace_apis_in_dygraph_only
def lgamma_(x, name=None):
r"""
Inplace version of ``lgamma`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_lgamma`.
"""
if in_dynamic_mode():
return _C_ops.lgamma_(x)
def neg(x, name=None):
"""
This function computes the negative of the Tensor elementwisely.
......@@ -4449,6 +4523,17 @@ def neg(x, name=None):
)
@inplace_apis_in_dygraph_only
def neg_(x, name=None):
r"""
Inplace version of ``neg`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_neg`.
"""
return x.scale_(
scale=-1.0, bias=0.0, bias_after_scale=True, act=None, name=name
)
def atan2(x, y, name=None):
r"""
Element-wise arctangent of x/y with consideration of the quadrant.
......@@ -4574,6 +4659,18 @@ def logit(x, eps=None, name=None):
return out
@inplace_apis_in_dygraph_only
def logit_(x, eps=None, name=None):
r"""
Inplace version of ``logit`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_logit`.
"""
if eps is None:
eps = 0.0
if in_dynamic_mode():
return _C_ops.logit_(x, eps)
def lerp(x, y, weight, name=None):
r"""
Does a linear interpolation between x and y based on weight.
......@@ -5322,6 +5419,29 @@ def frac(x, name=None):
return _elementwise_op(LayerHelper('elementwise_sub', **locals()))
@inplace_apis_in_dygraph_only
def frac_(x, name=None):
r"""
Inplace version of ``frac`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_frac`.
"""
if x.dtype not in [
paddle.int32,
paddle.int64,
paddle.float32,
paddle.float64,
]:
raise TypeError(
"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {}".format(
x.dtype
)
)
if in_dynamic_mode():
y = _C_ops.trunc(x)
return _C_ops.subtract_(x, y)
def sgn(x, name=None):
"""
For complex tensor, this API returns a new tensor whose elements have the same angles as the corresponding
......@@ -5884,6 +6004,17 @@ def i0(x, name=None):
return out
@inplace_apis_in_dygraph_only
def i0_(x, name=None):
r"""
Inplace version of ``i0`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_i0`.
"""
if in_dynamic_mode():
return _C_ops.i0_(x)
def i0e(x, name=None):
r"""
The function used to calculate exponentially scaled modified Bessel function of order 0.
......@@ -6046,6 +6177,27 @@ def polygamma(x, n, name=None):
return out
def polygamma_(x, n, name=None):
r"""
Inplace version of ``polygamma`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_polygamma`.
"""
if not isinstance(n, int):
raise TypeError(
"The input of n must be int type, but received: %s " % (type(n))
)
if n < 0:
raise ValueError(
"The input of n must be greater than or equal to 0. But received n = %s"
% (n)
)
if n == 0:
return digamma_(x)
else:
if in_dynamic_mode():
return _C_ops.polygamma_(x, n)
def ldexp(x, y, name=None):
"""
Compute the result of multiplying x by 2 to the power of y. The equation is:
......
......@@ -697,5 +697,141 @@ class TestDygraphInplacePowerScalar(TestDygraphInplaceWithContinuous):
paddle.pow_(var, [2])
class TestDygraphInplaceTriu(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.triu_(var, 0)
def non_inplace_api_processing(self, var):
return paddle.triu(var, 0)
class TestDygraphInplaceTril(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.tril_(var, 0)
def non_inplace_api_processing(self, var):
return paddle.tril(var, 0)
class TestDygraphInplaceLogit(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.logit_(var, 1e-3)
def non_inplace_api_processing(self, var):
return paddle.logit(var, 1e-3)
class TestDygraphInplaceLog(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.log_(var)
def non_inplace_api_processing(self, var):
return paddle.log(var)
class TestDygraphInplaceLog2(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.log2_(var)
def non_inplace_api_processing(self, var):
return paddle.log2(var)
class TestDygraphInplaceLog10(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.log10_(var)
def non_inplace_api_processing(self, var):
return paddle.log10(var)
class TestDygraphInplaceLog1p(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.log1p_(var)
def non_inplace_api_processing(self, var):
return paddle.log1p(var)
class TestDygraphInplaceTrunc(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.trunc_(var)
def non_inplace_api_processing(self, var):
return paddle.trunc(var)
class TestDygraphInplaceDigamma(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.digamma_(var)
def non_inplace_api_processing(self, var):
return paddle.digamma(var)
class TestDygraphInplaceNeg(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.neg_(var)
def non_inplace_api_processing(self, var):
return paddle.neg(var)
class TestDygraphInplaceLgamma(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.lgamma_(var)
def non_inplace_api_processing(self, var):
return paddle.lgamma(var)
class TestDygraphInplaceFrac(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.frac_(var)
def non_inplace_api_processing(self, var):
return paddle.frac(var)
class TestDygraphInplaceI0(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.i0_(var)
def non_inplace_api_processing(self, var):
return paddle.i0(var)
class TestDygraphInplacePolygamma(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.polygamma_(var, 1)
def non_inplace_api_processing(self, var):
return paddle.polygamma(var, 1)
class TestDygraphInplaceHardTanh(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.nn.functional.hardtanh_(var, -1.0, 1.0)
def non_inplace_api_processing(self, var):
return paddle.nn.functional.hardtanh(var, -1.0, 1.0)
class TestDygraphInplaceLeakyRelu(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.nn.functional.leaky_relu_(var, 0.01)
def non_inplace_api_processing(self, var):
return paddle.nn.functional.leaky_relu(var, 0.01)
class TestDygraphInplaceThresholdedRelu(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.nn.functional.thresholded_relu_(var, 1.0)
def non_inplace_api_processing(self, var):
return paddle.nn.functional.thresholded_relu(var, 1.0)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册