Unverified · Commit 19302938 authored by GGBond8488, committed by GitHub

batch add inplace api (#55078)

* batch add inplace api

* fix inplace fn generate

* add test for new inplace api

* fix typo

* fix typo

* fix typo

* fix test error

* fix atan2

* remove atan2

* auto generate inplace api

* fix inplace generate fn error

* fix windows error

* fix test error

* fix test error

* fix windows ci error

* fix test error

* fix test error

* fix test error

* fix eigen aliasing error in inplace

* remove elementwise_pow inplace

* fix doc error

* fix test error
Parent 5e6645d7
......@@ -237,7 +237,6 @@
func : ElementwiseInferMeta
kernel :
func : elementwise_pow
inplace: (x -> out)
backward : elementwise_pow_grad
- op : embedding
......
......@@ -13,6 +13,7 @@
kernel :
func : abs
data_type : x
inplace: (x -> out)
backward : abs_grad
- op : accuracy
......@@ -26,20 +27,22 @@
- op : acos
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : acos
inplace: (x -> out)
backward : acos_grad
- op : acosh
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : acosh
inplace: (x -> out)
backward : acosh_grad
- op : adagrad_
......@@ -90,12 +93,13 @@
- op : addmm
args : (Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0)
output : Tensor
output : Tensor(out)
infer_meta :
func : AddmmInferMeta
kernel :
func : addmm
data_type : x
inplace: (input -> out)
backward : addmm_grad
- op : affine_grid
......@@ -176,34 +180,37 @@
- op : asin
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : asin
inplace: (x -> out)
backward : asin_grad
- op : asinh
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : asinh
inplace: (x -> out)
backward : asinh_grad
- op : atan
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : atan
inplace: (x -> out)
backward : atan_grad
- op : atan2
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : Atan2InferMeta
kernel :
......@@ -212,11 +219,12 @@
- op : atanh
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : atanh
inplace: (x -> out)
backward : atanh_grad
- op : auc
......@@ -524,20 +532,22 @@
- op : cos
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : cos
inplace: (x -> out)
backward : cos_grad
- op : cosh
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : cosh
inplace: (x -> out)
backward : cosh_grad
- op : crop
......@@ -756,11 +766,12 @@
- op : erf
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : erf
inplace : (x -> out)
backward : erf_grad
- op : erfinv
......@@ -806,12 +817,13 @@
- op : expm1
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : expm1
inplace: (x -> out)
backward : expm1_grad
- op : fft_c2c
......@@ -2250,20 +2262,22 @@
- op : sin
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : sin
inplace: (x -> out)
backward : sin_grad
- op : sinh
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : sinh
inplace: (x -> out)
backward : sinh_grad
- op : slogdet
......@@ -2409,11 +2423,12 @@
- op : tan
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : tan
inplace: (x -> out)
backward : tan_grad
- op : tanh
......
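For context on the YAML entries above: giving an op a named output (`output : Tensor(out)`) together with `inplace: (x -> out)` tells the op code generator to emit an in-place variant (for example `sin_` alongside `sin`), which this PR then exposes as Python APIs. A minimal dygraph sketch of the resulting behaviour (example only, not part of the diff; assumes a Paddle build that includes this change):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.uniform(-1, 1, [3, 4]).astype("float32"))
reference = paddle.sin(x)      # out-of-place: allocates a new tensor
result = paddle.sin_(x)        # in-place: writes the result back into x
np.testing.assert_array_equal(result.numpy(), reference.numpy())
np.testing.assert_array_equal(x.numpy(), reference.numpy())  # x itself was overwritten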
......@@ -116,7 +116,10 @@ template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sine<T>());
// Note(GGBond8488): Since Eigen 3.3, an expression like {A = (B * A).cwiseAbs()}
// can give wrong results because of aliasing; for details see
// http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
out.device(d) = x.unaryExpr(Sine<T>()).eval();
}
};
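The note above concerns Eigen expression aliasing: the expression reads from the same buffer it is being assigned to, so forcing evaluation into a temporary with `.eval()` avoids the corruption. Since the other examples in this page use Python, the following is only a language-agnostic illustration of the same hazard class and of the "evaluate into a temporary first" fix (plain Python, not Eigen; example only, not part of the diff):

a = [1, 2, 3, 4]
# Aliasing version: each write to a[i] is visible to the later read of a[i - 1],
# so the loop accidentally computes running prefix sums.
for i in range(len(a)):
    a[i] = a[i] + (a[i - 1] if i > 0 else 0)

b = [1, 2, 3, 4]
src = list(b)  # evaluate the "input expression" into a temporary first (the .eval() analogue)
for i in range(len(b)):
    b[i] = src[i] + (src[i - 1] if i > 0 else 0)

print(a)  # [1, 3, 6, 10] -- corrupted by reading already-overwritten entries
print(b)  # [1, 3, 5, 7]  -- the intended element + previous-element sum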
......@@ -448,7 +451,7 @@ template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosine<T>());
out.device(d) = x.unaryExpr(Cosine<T>()).eval();
}
};
......@@ -762,7 +765,10 @@ template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Tangent<T>());
// Note(GGBond8488): Since Eigen 3.3, an expression like {A = (B * A).cwiseAbs()}
// can give wrong results because of aliasing; for details see
// http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
out.device(d) = x.unaryExpr(Tangent<T>()).eval();
}
};
......@@ -795,7 +801,7 @@ template <typename T>
struct SinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sinh<T>());
out.device(d) = x.unaryExpr(Sinh<T>()).eval();
}
};
......@@ -804,7 +810,7 @@ template <typename T>
struct CoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosh<T>());
out.device(d) = x.unaryExpr(Cosh<T>()).eval();
}
};
......@@ -855,7 +861,7 @@ template <typename T>
struct AcosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Acos<T>());
out.device(d) = x.unaryExpr(Acos<T>()).eval();
}
};
......@@ -892,7 +898,7 @@ template <typename T>
struct AsinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Asin<T>());
out.device(d) = x.unaryExpr(Asin<T>()).eval();
}
};
......@@ -929,7 +935,7 @@ template <typename T>
struct AtanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Atan<T>());
out.device(d) = x.unaryExpr(Atan<T>()).eval();
}
};
......@@ -977,7 +983,7 @@ template <typename T>
struct AcoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Acosh<T>());
out.device(d) = x.unaryExpr(Acosh<T>()).eval();
}
};
......@@ -1014,7 +1020,7 @@ template <typename T>
struct AsinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Asinh<T>());
out.device(d) = x.unaryExpr(Asinh<T>()).eval();
}
};
......@@ -1051,7 +1057,7 @@ template <typename T>
struct AtanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Atanh<T>());
out.device(d) = x.unaryExpr(Atanh<T>()).eval();
}
};
......
......@@ -203,14 +203,21 @@ from .tensor.manipulation import index_put # noqa: F401
from .tensor.manipulation import index_put_ # noqa: F401
from .tensor.manipulation import unflatten # noqa: F401
from .tensor.math import abs # noqa: F401
from .tensor.math import abs_ # noqa: F401
from .tensor.math import acos # noqa: F401
from .tensor.math import acos_ # noqa: F401
from .tensor.math import asin # noqa: F401
from .tensor.math import asin_ # noqa: F401
from .tensor.math import atan # noqa: F401
from .tensor.math import atan_ # noqa: F401
from .tensor.math import atan2 # noqa: F401
from .tensor.math import ceil # noqa: F401
from .tensor.math import cos # noqa: F401
from .tensor.math import cos_ # noqa: F401
from .tensor.math import tan # noqa: F401
from .tensor.math import tan_ # noqa: F401
from .tensor.math import cosh # noqa: F401
from .tensor.math import cosh_ # noqa: F401
from .tensor.math import cumsum # noqa: F401
from .tensor.math import cummax # noqa: F401
from .tensor.math import cummin # noqa: F401
......@@ -219,6 +226,7 @@ from .tensor.math import logcumsumexp # noqa: F401
from .tensor.math import logit # noqa: F401
from .tensor.math import exp # noqa: F401
from .tensor.math import expm1 # noqa: F401
from .tensor.math import expm1_ # noqa: F401
from .tensor.math import floor # noqa: F401
from .tensor.math import increment # noqa: F401
from .tensor.math import log # noqa: F401
......@@ -235,9 +243,12 @@ from .tensor.math import rsqrt # noqa: F401
from .tensor.math import scale # noqa: F401
from .tensor.math import sign # noqa: F401
from .tensor.math import sin # noqa: F401
from .tensor.math import sin_ # noqa: F401
from .tensor.math import sinh # noqa: F401
from .tensor.math import sinh_ # noqa: F401
from .tensor.math import sqrt # noqa: F401
from .tensor.math import square # noqa: F401
from .tensor.math import square_ # noqa: F401
from .tensor.math import stanh # noqa: F401
from .tensor.math import sum # noqa: F401
from .tensor.math import nan_to_num # noqa: F401
......@@ -269,7 +280,9 @@ from .tensor.math import logaddexp # noqa: F401
from .tensor.math import inverse # noqa: F401
from .tensor.math import log1p # noqa: F401
from .tensor.math import erf # noqa: F401
from .tensor.math import erf_ # noqa: F401
from .tensor.math import addmm # noqa: F401
from .tensor.math import addmm_ # noqa: F401
from .tensor.math import clip # noqa: F401
from .tensor.math import trace # noqa: F401
from .tensor.math import diagonal # noqa: F401
......@@ -285,8 +298,11 @@ from .tensor.math import digamma # noqa: F401
from .tensor.math import neg # noqa: F401
from .tensor.math import lgamma # noqa: F401
from .tensor.math import acosh # noqa: F401
from .tensor.math import acosh_ # noqa: F401
from .tensor.math import asinh # noqa: F401
from .tensor.math import asinh_ # noqa: F401
from .tensor.math import atanh # noqa: F401
from .tensor.math import atanh_ # noqa: F401
from .tensor.math import lerp # noqa: F401
from .tensor.math import erfinv # noqa: F401
from .tensor.math import rad2deg # noqa: F401
......@@ -431,6 +447,7 @@ __all__ = [ # noqa
'complex64',
'complex128',
'addmm',
'addmm_',
'allclose',
'isclose',
't',
......@@ -468,7 +485,9 @@ __all__ = [ # noqa
'where',
'log1p',
'cos',
'cos_',
'tan',
'tan_',
'mean',
'mode',
'mv',
......@@ -543,6 +562,7 @@ __all__ = [ # noqa
'less_equal',
'triu',
'sin',
'sin_',
'dist',
'cdist',
'unbind',
......@@ -560,6 +580,7 @@ __all__ = [ # noqa
'is_grad_enabled',
'mod',
'abs',
'abs_',
'tril',
'pow',
'pow_',
......@@ -571,12 +592,15 @@ __all__ = [ # noqa
'matmul',
'seed',
'acos',
'acos_',
'logical_xor',
'exp',
'expm1',
'expm1_',
'bernoulli',
'poisson',
'sinh',
'sinh_',
'round',
'DataParallel',
'argmin',
......@@ -590,9 +614,11 @@ __all__ = [ # noqa
'inner',
'outer',
'square',
'square_',
'divide',
'ceil',
'atan',
'atan_',
'atan2',
'rad2deg',
'deg2rad',
......@@ -618,6 +644,7 @@ __all__ = [ # noqa
'dot',
'increment',
'erf',
'erf_',
'bmm',
'chunk',
'tolist',
......
......@@ -141,14 +141,21 @@ from .manipulation import index_put # noqa: F401
from .manipulation import index_put_ # noqa: F401
from .manipulation import unflatten # noqa: F401
from .math import abs # noqa: F401
from .math import abs_ # noqa: F401
from .math import acos # noqa: F401
from .math import acos_ # noqa: F401
from .math import asin # noqa: F401
from .math import asin_ # noqa: F401
from .math import atan # noqa: F401
from .math import atan_ # noqa: F401
from .math import ceil # noqa: F401
from .math import ceil_ # noqa: F401
from .math import cos # noqa: F401
from .math import cos_ # noqa: F401
from .math import tan # noqa: F401
from .math import tan_ # noqa: F401
from .math import cosh # noqa: F401
from .math import cosh_ # noqa: F401
from .math import cumsum # noqa: F401
from .math import cummax # noqa: F401
from .math import cummin # noqa: F401
......@@ -175,7 +182,9 @@ from .math import scale # noqa: F401
from .math import scale_ # noqa: F401
from .math import sign # noqa: F401
from .math import sin # noqa: F401
from .math import sin_ # noqa: F401
from .math import sinh # noqa: F401
from .math import sinh_ # noqa: F401
from .math import sqrt # noqa: F401
from .math import sqrt_ # noqa: F401
from .math import square # noqa: F401
......@@ -216,6 +225,7 @@ from .math import log10 # noqa: F401
from .math import log1p # noqa: F401
from .math import erf # noqa: F401
from .math import addmm # noqa: F401
from .math import addmm_ # noqa: F401
from .math import clip # noqa: F401
from .math import clip_ # noqa: F401
from .math import trace # noqa: F401
......@@ -234,8 +244,11 @@ from .math import neg # noqa: F401
from .math import lgamma # noqa: F401
from .math import diagonal # noqa: F401
from .math import acosh # noqa: F401
from .math import acosh_ # noqa: F401
from .math import asinh # noqa: F401
from .math import asinh_ # noqa: F401
from .math import atanh # noqa: F401
from .math import atanh_ # noqa: F401
from .math import lerp # noqa: F401
from .math import lerp_ # noqa: F401
from .math import erfinv # noqa: F401
......@@ -421,6 +434,7 @@ tensor_method_func = [ # noqa
'log1p',
'erf',
'addmm',
'addmm_',
'clip',
'clip_',
'trace',
......
......@@ -14,7 +14,6 @@
import re
import string
import warnings
from io import StringIO
from paddle import _C_ops, _legacy_C_ops
......@@ -352,22 +351,14 @@ def generate_inplace_fn(inplace_op_type):
else:
op = getattr(_legacy_C_ops, inplace_op_type)
return op(x)
else:
warnings.warn(
"In static graph mode, {}() is the same as {}() and does not perform inplace operation.".format(
inplace_op_type, origin_op_type
)
)
return generate_activation_fn(origin_op_type)(x, name)
func.__name__ = inplace_op_type
func.__doc__ = """
Inplace version of ``{}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_fluid_layers_{}`.
Please refer to :ref:`api_paddle_{}`.
""".format(
origin_op_type, origin_op_type
)
return func
......
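generate_inplace_fn above is what builds the Python-level in-place wrappers from the original op names. A simplified, hypothetical sketch of the pattern (the names make_inplace_fn and dygraph_op are illustrative only, not Paddle's actual generator; example, not part of the diff):

def make_inplace_fn(inplace_name, origin_name, dygraph_op):
    # dygraph_op stands for the in-place C++ binding, e.g. _C_ops.sin_ (illustrative).
    def func(x, name=None):
        return dygraph_op(x)

    func.__name__ = inplace_name
    func.__doc__ = (
        "Inplace version of ``{0}`` API, the output Tensor will be inplaced "
        "with input ``x``. Please refer to :ref:`api_paddle_{0}`."
    ).format(origin_name)
    return func

# Wiring sketch only: sin_ = make_inplace_fn("sin_", "sin", _C_ops.sin_)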
......@@ -43,20 +43,31 @@ from .creation import _complex_to_real_dtype
from .layer_function_generator import generate_layer_fn, templatedoc
from .manipulation import cast
from .ops import abs # noqa: F401
from .ops import abs_ # noqa: F401
from .ops import acos # noqa: F401
from .ops import acos_ # noqa: F401
from .ops import acosh # noqa: F401
from .ops import acosh_ # noqa: F401
from .ops import asin # noqa: F401
from .ops import asin_ # noqa: F401
from .ops import asinh # noqa: F401
from .ops import asinh_ # noqa: F401
from .ops import atan # noqa: F401
from .ops import atan_ # noqa: F401
from .ops import atanh # noqa: F401
from .ops import atanh_ # noqa: F401
from .ops import ceil # noqa: F401
from .ops import ceil_ # noqa: F401
from .ops import cos # noqa: F401
from .ops import cos_ # noqa: F401
from .ops import cosh # noqa: F401
from .ops import cosh_ # noqa: F401
from .ops import erf # noqa: F401
from .ops import erf_ # noqa: F401
from .ops import exp # noqa: F401
from .ops import exp_ # noqa: F401
from .ops import expm1 # noqa: F401
from .ops import expm1_ # noqa: F401
from .ops import floor # noqa: F401
from .ops import floor_ # noqa: F401
from .ops import reciprocal # noqa: F401
......@@ -68,11 +79,15 @@ from .ops import rsqrt_ # noqa: F401
from .ops import sigmoid # noqa: F401
from .ops import sigmoid_ # noqa: F401
from .ops import sin # noqa: F401
from .ops import sin_ # noqa: F401
from .ops import sinh # noqa: F401
from .ops import sinh_ # noqa: F401
from .ops import sqrt # noqa: F401
from .ops import sqrt_ # noqa: F401
from .ops import square # noqa: F401
from .ops import square_ # noqa: F401
from .ops import tan # noqa: F401
from .ops import tan_ # noqa: F401
__all__ = []
......@@ -482,12 +497,8 @@ def pow_(x, y, name=None):
"""
if isinstance(y, (int, float)):
return _C_ops.pow_(x, y)
elif isinstance(y, (paddle.Tensor, Variable)):
return _C_ops.elementwise_pow_(x, y)
else:
raise TypeError(
'y must be scalar or tensor type, but received: %s ' % (type(y))
)
raise TypeError('y must be scalar type, but received: %s ' % (type(y)))
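With the elementwise_pow in-place branch removed, paddle.pow_ now only accepts a Python scalar exponent, and anything else raises TypeError. A short dygraph sketch (example only, not part of the diff; assumes a build with this change):

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
paddle.pow_(x, 2)          # OK: scalar exponent, x becomes [1., 4., 9.] in place
try:
    paddle.pow_(x, [2])    # non-scalar exponent is rejected
except TypeError as err:
    print(err)             # "y must be scalar type, but received: ..."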
OP_NAMEMAPPING = {
......@@ -2043,6 +2054,66 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
return out
@inplace_apis_in_dygraph_only
def addmm_(input, x, y, beta=1.0, alpha=1.0, name=None):
"""
Inplace version of ``addmm`` API, the output Tensor will be inplaced with input ``input``.
Please refer to :ref:`api_label_addmm`.
"""
input_shape = input.shape
x_shape = x.shape
y_shape = y.shape
if not len(x_shape) == len(y_shape) == 2:
raise ValueError(
"The dimention of x, y should be 2 but receive x's shape: {}, y's shape: {}".format(
x_shape, y_shape
)
)
if x_shape[1] != y_shape[0]:
raise ValueError(
"The input Variable x's width must be equal with Variable y' height. But received x's shape = {}, y's shape = {}.".format(
x_shape, y_shape
)
)
if len(input_shape) == 2:
if input_shape[0] != x_shape[0]:
if input_shape[0] != 1:
raise ValueError(
"When x's dimension[0] is not equal with input's dimension[0], input's dimension[0] must be 1 but got {}".format(
input_shape[0]
)
)
if input_shape[1] != y_shape[1] and input_shape[1] != 1:
raise ValueError(
"When y's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {}".format(
input_shape[1]
)
)
if input_shape[1] != y_shape[1]:
if input_shape[1] != 1:
raise ValueError(
"When y's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {}".format(
input_shape[1]
)
)
elif len(input_shape) == 1:
if input_shape[0] not in (y_shape[1], 1):
raise ValueError(
"The input's shape: {} is not broadcastable with [x.shape[0], y.shape[1]]: [{},{}]".format(
input_shape, x_shape[0], y_shape[1]
)
)
else:
raise ValueError(
"The dimention of input should be 2 or 1 but receive input's shape: {}".format(
input_shape
)
)
if in_dynamic_mode():
return _C_ops.addmm_(input, x, y, beta, alpha)
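The validation above requires x and y to be 2-D with matching inner dimensions, and input to be broadcastable to [x.shape[0], y.shape[1]]. A small dygraph sketch (example only, not part of the diff; assumes a build with this change):

import paddle

inp = paddle.ones([2, 3])
x = paddle.ones([2, 4])
y = paddle.ones([4, 3])
out = paddle.addmm_(inp, x, y, beta=1.0, alpha=1.0)  # overwrites inp with beta * inp + alpha * (x @ y)
print(out.shape)  # [2, 3]; out shares storage with inp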
def renorm(x, p, axis, max_norm):
"""
**renorm**
......
......@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
from .. import _C_ops
from ..fluid.data_feeder import check_variable_and_dtype
from ..framework import LayerHelper, in_dynamic_mode
......@@ -47,6 +50,21 @@ __inplace_unary_func__ = [
'round_',
'reciprocal_',
'sigmoid_',
'abs_',
'sin_',
'sinh_',
'asin_',
'asinh_',
'cos_',
'cosh_',
'acos_',
'acosh_',
'tan_',
'atan_',
'atanh_',
'expm1_',
'erf_',
'square_',
]
__all__ = []
......@@ -76,7 +94,9 @@ for _OP in set(__inplace_unary_func__):
_new_OP = _OP
if _OP in __deprecated_func_name__:
_new_OP = __deprecated_func_name__[_OP]
_func = generate_inplace_fn(_OP)
func = generate_inplace_fn(_OP)
func.__module__ = __name__
_func = inplace_apis_in_dygraph_only(func)
globals()[_OP] = _func
add_sample_code(
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import unittest
import numpy as np
......@@ -123,6 +124,14 @@ class TestDygraphInplace(unittest.TestCase):
inplace_var[0] = 2.0
np.testing.assert_array_equal(var.numpy(), inplace_var.numpy())
def test_forward_result(self):
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
no_inplace_var = self.non_inplace_api_processing(var)
inplace_var = self.inplace_api_processing(var)
np.testing.assert_array_equal(
no_inplace_var.numpy(), inplace_var.numpy()
)
def test_forward_version(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
......@@ -241,6 +250,52 @@ class TestDygraphInplace(unittest.TestCase):
np.testing.assert_array_equal(grad_var_a_inplace, grad_var_a)
class TestDygraphInplaceWithContinuous(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
self.dtype = "float32"
def set_np_compare_func(self):
np_array_equal_with_nan = functools.partial(
np.array_equal, equal_nan=True
)
self.np_compare = np_array_equal_with_nan
def non_inplace_api_processing(self, var):
return paddle.sin(var)
def inplace_api_processing(self, var):
return paddle.sin_(var)
def test_continuous_inplace_backward(self):
# APIs whose gradient depends only on the input copy the input before the
# inplace calculation, so continuous inplace backward calculation is supported here.
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.inplace_api_processing(var_b)
var_d = self.inplace_api_processing(var_c)
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.non_inplace_api_processing(var_b)
var_d = self.non_inplace_api_processing(var_c)
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
class TestDygraphInplaceUnsqueeze(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.unsqueeze(var, -1)
......@@ -506,5 +561,141 @@ class TestGetitemBeforeInplace(unittest.TestCase):
loss.backward()
class TestDygraphInplaceAsin(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.asin(var)
def inplace_api_processing(self, var):
return paddle.asin_(var)
class TestDygraphInplaceSinh(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.sinh(var)
def inplace_api_processing(self, var):
return paddle.sinh_(var)
class TestDygraphInplaceAsinh(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.asinh(var)
def inplace_api_processing(self, var):
return paddle.asinh_(var)
class TestDygraphInplaceAbs(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.abs(var)
def inplace_api_processing(self, var):
return paddle.abs_(var)
class TestDygraphInplaceCos(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.cos(var)
def inplace_api_processing(self, var):
return paddle.cos_(var)
class TestDygraphInplaceCosh(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.cosh(var)
def inplace_api_processing(self, var):
return paddle.cosh_(var)
class TestDygraphInplaceAcos(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.acos(var)
def inplace_api_processing(self, var):
return paddle.acos_(var)
class TestDygraphInplaceAcosh(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.acosh(var)
def inplace_api_processing(self, var):
return paddle.acosh_(var)
class TestDygraphInplaceTan(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.tan(var)
def inplace_api_processing(self, var):
return paddle.tan_(var)
class TestDygraphInplaceATan(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.atan(var)
def inplace_api_processing(self, var):
return paddle.atan_(var)
class TestDygraphInplaceATanh(TestDygraphInplaceWithContinuous):
def non_inplace_api_processing(self, var):
return paddle.atanh(var)
def inplace_api_processing(self, var):
return paddle.atanh_(var)
class TestDygraphInplaceAddMM(TestDygraphInplaceWithContinuous):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 10])
self.dtype = "float32"
self.x = paddle.randn([10, 10], dtype="float32")
self.y = paddle.randn([10, 10], dtype="float32")
def non_inplace_api_processing(self, var):
return paddle.addmm(var, x=self.x, y=self.y)
def inplace_api_processing(self, var):
return paddle.addmm_(var, x=self.x, y=self.y)
def test_errors(self):
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
x1 = paddle.randn([10])
self.assertRaises(ValueError, paddle.addmm_, var, x1, self.y)
y1 = paddle.randn([12, 10])
self.assertRaises(ValueError, paddle.addmm_, var, self.x, y1)
x2 = paddle.randn([12, 10])
self.assertRaises(ValueError, paddle.addmm_, var, x2, self.y)
var1 = paddle.randn([1, 5])
self.assertRaises(ValueError, paddle.addmm_, var1, x2, self.y)
y2 = paddle.randn([10, 12])
self.assertRaises(ValueError, paddle.addmm_, var, self.x, y2)
var2 = paddle.randn([6])
self.assertRaises(ValueError, paddle.addmm_, var2, self.x, self.y)
var3 = paddle.randn([2, 3, 4])
self.assertRaises(ValueError, paddle.addmm_, var3, self.x, self.y)
class TestDygraphInplacePowerScalar(TestDygraphInplaceWithContinuous):
def inplace_api_processing(self, var):
return paddle.pow_(var, 2)
def non_inplace_api_processing(self, var):
return paddle.pow(var, 2)
def test_type_error(self):
var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype)
with self.assertRaisesRegex(
TypeError,
'y must be scalar type, but received: %s ' % (type([2])),
):
paddle.pow_(var, [2])
if __name__ == '__main__':
unittest.main()
......@@ -15,7 +15,6 @@
import unittest
import numpy as np
from test_inplace import TestDygraphInplace
import paddle
from paddle.fluid import core
......@@ -214,40 +213,5 @@ class TestPowerError(unittest.TestCase):
self.assertRaises(TypeError, paddle.pow, x, str(y))
class TestInplacePowerScalar(TestDygraphInplace):
def set_np_compare_func(self):
self.np_compare = np.allclose
def inplace_api_processing(self, var):
return paddle.pow_(var, 2)
def non_inplace_api_processing(self, var):
return paddle.pow(var, 2)
class TestInplacePowerTensor(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
self.dtype = "float32"
self.y = paddle.ones([10, 20, 1], dtype="float32") * 2
def set_np_compare_func(self):
self.np_compare = np.allclose
def inplace_api_processing(self, var):
return paddle.pow_(var, self.y)
def non_inplace_api_processing(self, var):
return paddle.pow(var, self.y)
def test_type_error(self):
var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype)
with self.assertRaisesRegex(
TypeError,
'y must be scalar or tensor type, but received: %s ' % (type([2])),
):
paddle.pow_(var, [2])
if __name__ == '__main__':
unittest.main()