Unverified commit 19302938, authored by GGBond8488, committed by GitHub

batch add inplace api (#55078)

* batch add inplace api

* fix inplace fn generate

* add test for new inplace api

* fix typo

* fix typo

* fix typo

* fix test error

* fix atan2

* remove atan2

* auto generate inplace api

* fix inplace generate fn error

* fix windows error

* fix test error

* fix test error

* fix windows ci error

* fix test error

* fix test error

* fix test error

* fix eigen aliasing error in inplace

* remove elementwise_pow inplace

* fix doc error

* fix test error
Parent 5e6645d7
@@ -237,7 +237,6 @@
     func : ElementwiseInferMeta
   kernel :
     func : elementwise_pow
-  inplace: (x -> out)
   backward : elementwise_pow_grad
 
 - op : embedding
......
@@ -13,6 +13,7 @@
   kernel :
     func : abs
     data_type : x
+  inplace: (x -> out)
   backward : abs_grad
 
 - op : accuracy

@@ -26,20 +27,22 @@
 - op : acos
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : acos
+  inplace: (x -> out)
   backward : acos_grad
 
 - op : acosh
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : acosh
+  inplace: (x -> out)
   backward : acosh_grad
 
 - op : adagrad_

@@ -90,12 +93,13 @@
 - op : addmm
   args : (Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : AddmmInferMeta
   kernel :
     func : addmm
     data_type : x
+  inplace: (input -> out)
   backward : addmm_grad
 
 - op : affine_grid

@@ -176,34 +180,37 @@
 - op : asin
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : asin
+  inplace: (x -> out)
   backward : asin_grad
 
 - op : asinh
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : asinh
+  inplace: (x -> out)
   backward : asinh_grad
 
 - op : atan
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : atan
+  inplace: (x -> out)
   backward : atan_grad
 
 - op : atan2
   args : (Tensor x, Tensor y)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : Atan2InferMeta
   kernel :

@@ -212,11 +219,12 @@
 - op : atanh
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : atanh
+  inplace: (x -> out)
   backward : atanh_grad
 
 - op : auc

@@ -524,20 +532,22 @@
 - op : cos
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : cos
+  inplace: (x -> out)
   backward : cos_grad
 
 - op : cosh
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : cosh
+  inplace: (x -> out)
   backward : cosh_grad
 
 - op : crop

@@ -756,11 +766,12 @@
 - op : erf
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : erf
+  inplace : (x -> out)
   backward : erf_grad
 
 - op : erfinv

@@ -806,12 +817,13 @@
 - op : expm1
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
   kernel :
     func : expm1
+  inplace: (x -> out)
   backward : expm1_grad
 
 - op : fft_c2c

@@ -2250,20 +2262,22 @@
 - op : sin
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : sin
+  inplace: (x -> out)
   backward : sin_grad
 
 - op : sinh
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : sinh
+  inplace: (x -> out)
   backward : sinh_grad
 
 - op : slogdet

@@ -2409,11 +2423,12 @@
 - op : tan
   args : (Tensor x)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
   kernel :
     func : tan
+  inplace: (x -> out)
   backward : tan_grad
 
 - op : tanh
......
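For context: an `inplace: (x -> out)` entry in the yaml tells Paddle's operator code generator to also emit a trailing-underscore variant (`sin_`, `cos_`, ...) whose output reuses the input buffer. A minimal dygraph sketch of the resulting Python API (values illustrative, not taken from this diff):

    import paddle

    x = paddle.to_tensor([0.0, 0.5, 1.0])
    y = paddle.sin_(x)   # generated in-place variant

    print(x.numpy())     # [0.         0.47942554 0.84147098]; x itself was mutated
    y[0] = 2.0           # y shares storage with x,
    print(float(x[0]))   # 2.0; so writes through y are visible in x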
@@ -116,7 +116,10 @@ template <typename T>
 struct SinFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Sine<T>());
+    // Note(GGBond8488): Since Eigen 3.3, an expression like {A = (B * A).cwiseAbs()}
+    // gives a wrong result due to aliasing; for details see
+    // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
+    out.device(d) = x.unaryExpr(Sine<T>()).eval();
   }
 };

@@ -448,7 +451,7 @@ template <typename T>
 struct CosFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Cosine<T>());
+    out.device(d) = x.unaryExpr(Cosine<T>()).eval();
   }
 };

@@ -762,7 +765,10 @@ template <typename T>
 struct TanFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Tangent<T>());
+    // Note(GGBond8488): Since Eigen 3.3, an expression like {A = (B * A).cwiseAbs()}
+    // gives a wrong result due to aliasing; for details see
+    // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
+    out.device(d) = x.unaryExpr(Tangent<T>()).eval();
   }
 };

@@ -795,7 +801,7 @@ template <typename T>
 struct SinhFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Sinh<T>());
+    out.device(d) = x.unaryExpr(Sinh<T>()).eval();
   }
 };

@@ -804,7 +810,7 @@ template <typename T>
 struct CoshFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Cosh<T>());
+    out.device(d) = x.unaryExpr(Cosh<T>()).eval();
   }
 };

@@ -855,7 +861,7 @@ template <typename T>
 struct AcosFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Acos<T>());
+    out.device(d) = x.unaryExpr(Acos<T>()).eval();
   }
 };

@@ -892,7 +898,7 @@ template <typename T>
 struct AsinFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Asin<T>());
+    out.device(d) = x.unaryExpr(Asin<T>()).eval();
   }
 };

@@ -929,7 +935,7 @@ template <typename T>
 struct AtanFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Atan<T>());
+    out.device(d) = x.unaryExpr(Atan<T>()).eval();
   }
 };

@@ -977,7 +983,7 @@ template <typename T>
 struct AcoshFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Acosh<T>());
+    out.device(d) = x.unaryExpr(Acosh<T>()).eval();
   }
 };

@@ -1014,7 +1020,7 @@ template <typename T>
 struct AsinhFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Asinh<T>());
+    out.device(d) = x.unaryExpr(Asinh<T>()).eval();
   }
 };

@@ -1051,7 +1057,7 @@ template <typename T>
 struct AtanhFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr(Atanh<T>());
+    out.device(d) = x.unaryExpr(Atanh<T>()).eval();
   }
 };
......
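The `.eval()` added to each functor forces the Eigen expression into a temporary before it is assigned, which keeps the kernel correct when `out` aliases `x` in the new inplace paths. A toy illustration of the underlying read-after-overwrite hazard in plain NumPy (an analogy, not the Eigen mechanics themselves):

    import numpy as np

    a = np.array([1, 2, 3, 4])
    # Writing results into the buffer still being read corrupts later reads:
    for i in range(1, len(a)):
        a[i] = a[i] + a[i - 1]   # a[i - 1] was already overwritten
    print(a)                     # [ 1  3  6 10], a running sum, not [1 3 5 7]

    b = np.array([1, 2, 3, 4])
    tmp = b[1:] + b[:-1]         # materialize into a temporary first (.eval() analogue)
    b[1:] = tmp
    print(b)                     # [1 3 5 7], the intended pairwise sums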
@@ -203,14 +203,21 @@ from .tensor.manipulation import index_put  # noqa: F401
 from .tensor.manipulation import index_put_  # noqa: F401
 from .tensor.manipulation import unflatten  # noqa: F401
 from .tensor.math import abs  # noqa: F401
+from .tensor.math import abs_  # noqa: F401
 from .tensor.math import acos  # noqa: F401
+from .tensor.math import acos_  # noqa: F401
 from .tensor.math import asin  # noqa: F401
+from .tensor.math import asin_  # noqa: F401
 from .tensor.math import atan  # noqa: F401
+from .tensor.math import atan_  # noqa: F401
 from .tensor.math import atan2  # noqa: F401
 from .tensor.math import ceil  # noqa: F401
 from .tensor.math import cos  # noqa: F401
+from .tensor.math import cos_  # noqa: F401
 from .tensor.math import tan  # noqa: F401
+from .tensor.math import tan_  # noqa: F401
 from .tensor.math import cosh  # noqa: F401
+from .tensor.math import cosh_  # noqa: F401
 from .tensor.math import cumsum  # noqa: F401
 from .tensor.math import cummax  # noqa: F401
 from .tensor.math import cummin  # noqa: F401

@@ -219,6 +226,7 @@ from .tensor.math import logcumsumexp  # noqa: F401
 from .tensor.math import logit  # noqa: F401
 from .tensor.math import exp  # noqa: F401
 from .tensor.math import expm1  # noqa: F401
+from .tensor.math import expm1_  # noqa: F401
 from .tensor.math import floor  # noqa: F401
 from .tensor.math import increment  # noqa: F401
 from .tensor.math import log  # noqa: F401

@@ -235,9 +243,12 @@ from .tensor.math import rsqrt  # noqa: F401
 from .tensor.math import scale  # noqa: F401
 from .tensor.math import sign  # noqa: F401
 from .tensor.math import sin  # noqa: F401
+from .tensor.math import sin_  # noqa: F401
 from .tensor.math import sinh  # noqa: F401
+from .tensor.math import sinh_  # noqa: F401
 from .tensor.math import sqrt  # noqa: F401
 from .tensor.math import square  # noqa: F401
+from .tensor.math import square_  # noqa: F401
 from .tensor.math import stanh  # noqa: F401
 from .tensor.math import sum  # noqa: F401
 from .tensor.math import nan_to_num  # noqa: F401

@@ -269,7 +280,9 @@ from .tensor.math import logaddexp  # noqa: F401
 from .tensor.math import inverse  # noqa: F401
 from .tensor.math import log1p  # noqa: F401
 from .tensor.math import erf  # noqa: F401
+from .tensor.math import erf_  # noqa: F401
 from .tensor.math import addmm  # noqa: F401
+from .tensor.math import addmm_  # noqa: F401
 from .tensor.math import clip  # noqa: F401
 from .tensor.math import trace  # noqa: F401
 from .tensor.math import diagonal  # noqa: F401

@@ -285,8 +298,11 @@ from .tensor.math import digamma  # noqa: F401
 from .tensor.math import neg  # noqa: F401
 from .tensor.math import lgamma  # noqa: F401
 from .tensor.math import acosh  # noqa: F401
+from .tensor.math import acosh_  # noqa: F401
 from .tensor.math import asinh  # noqa: F401
+from .tensor.math import asinh_  # noqa: F401
 from .tensor.math import atanh  # noqa: F401
+from .tensor.math import atanh_  # noqa: F401
 from .tensor.math import lerp  # noqa: F401
 from .tensor.math import erfinv  # noqa: F401
 from .tensor.math import rad2deg  # noqa: F401

@@ -431,6 +447,7 @@ __all__ = [  # noqa
     'complex64',
     'complex128',
     'addmm',
+    'addmm_',
     'allclose',
     'isclose',
     't',

@@ -468,7 +485,9 @@ __all__ = [  # noqa
     'where',
     'log1p',
     'cos',
+    'cos_',
     'tan',
+    'tan_',
     'mean',
     'mode',
     'mv',

@@ -543,6 +562,7 @@ __all__ = [  # noqa
     'less_equal',
     'triu',
     'sin',
+    'sin_',
     'dist',
     'cdist',
     'unbind',

@@ -560,6 +580,7 @@ __all__ = [  # noqa
     'is_grad_enabled',
     'mod',
     'abs',
+    'abs_',
     'tril',
     'pow',
     'pow_',

@@ -571,12 +592,15 @@ __all__ = [  # noqa
     'matmul',
     'seed',
     'acos',
+    'acos_',
     'logical_xor',
     'exp',
     'expm1',
+    'expm1_',
     'bernoulli',
     'poisson',
     'sinh',
+    'sinh_',
     'round',
     'DataParallel',
     'argmin',

@@ -590,9 +614,11 @@ __all__ = [  # noqa
     'inner',
     'outer',
     'square',
+    'square_',
     'divide',
     'ceil',
     'atan',
+    'atan_',
     'atan2',
     'rad2deg',
     'deg2rad',

@@ -618,6 +644,7 @@ __all__ = [  # noqa
     'dot',
     'increment',
     'erf',
+    'erf_',
     'bmm',
     'chunk',
     'tolist',
......
@@ -141,14 +141,21 @@ from .manipulation import index_put  # noqa: F401
 from .manipulation import index_put_  # noqa: F401
 from .manipulation import unflatten  # noqa: F401
 from .math import abs  # noqa: F401
+from .math import abs_  # noqa: F401
 from .math import acos  # noqa: F401
+from .math import acos_  # noqa: F401
 from .math import asin  # noqa: F401
+from .math import asin_  # noqa: F401
 from .math import atan  # noqa: F401
+from .math import atan_  # noqa: F401
 from .math import ceil  # noqa: F401
 from .math import ceil_  # noqa: F401
 from .math import cos  # noqa: F401
+from .math import cos_  # noqa: F401
 from .math import tan  # noqa: F401
+from .math import tan_  # noqa: F401
 from .math import cosh  # noqa: F401
+from .math import cosh_  # noqa: F401
 from .math import cumsum  # noqa: F401
 from .math import cummax  # noqa: F401
 from .math import cummin  # noqa: F401

@@ -175,7 +182,9 @@ from .math import scale  # noqa: F401
 from .math import scale_  # noqa: F401
 from .math import sign  # noqa: F401
 from .math import sin  # noqa: F401
+from .math import sin_  # noqa: F401
 from .math import sinh  # noqa: F401
+from .math import sinh_  # noqa: F401
 from .math import sqrt  # noqa: F401
 from .math import sqrt_  # noqa: F401
 from .math import square  # noqa: F401

@@ -216,6 +225,7 @@ from .math import log10  # noqa: F401
 from .math import log1p  # noqa: F401
 from .math import erf  # noqa: F401
 from .math import addmm  # noqa: F401
+from .math import addmm_  # noqa: F401
 from .math import clip  # noqa: F401
 from .math import clip_  # noqa: F401
 from .math import trace  # noqa: F401

@@ -234,8 +244,11 @@ from .math import neg  # noqa: F401
 from .math import lgamma  # noqa: F401
 from .math import diagonal  # noqa: F401
 from .math import acosh  # noqa: F401
+from .math import acosh_  # noqa: F401
 from .math import asinh  # noqa: F401
+from .math import asinh_  # noqa: F401
 from .math import atanh  # noqa: F401
+from .math import atanh_  # noqa: F401
 from .math import lerp  # noqa: F401
 from .math import lerp_  # noqa: F401
 from .math import erfinv  # noqa: F401

@@ -421,6 +434,7 @@ tensor_method_func = [  # noqa
     'log1p',
     'erf',
     'addmm',
+    'addmm_',
     'clip',
     'clip_',
     'trace',
......
@@ -14,7 +14,6 @@
 import re
 import string
-import warnings
 from io import StringIO
 
 from paddle import _C_ops, _legacy_C_ops

@@ -352,22 +351,14 @@ def generate_inplace_fn(inplace_op_type):
             else:
                 op = getattr(_legacy_C_ops, inplace_op_type)
             return op(x)
-        else:
-            warnings.warn(
-                "In static graph mode, {}() is the same as {}() and does not perform inplace operation.".format(
-                    inplace_op_type, origin_op_type
-                )
-            )
-            return generate_activation_fn(origin_op_type)(x, name)
 
     func.__name__ = inplace_op_type
     func.__doc__ = """
 Inplace version of ``{}`` API, the output Tensor will be inplaced with input ``x``.
-Please refer to :ref:`api_fluid_layers_{}`.
+Please refer to :ref:`api_paddle_{}`.
 """.format(
         origin_op_type, origin_op_type
     )
     return func
......
@@ -43,20 +43,31 @@ from .creation import _complex_to_real_dtype
 from .layer_function_generator import generate_layer_fn, templatedoc
 from .manipulation import cast
 from .ops import abs  # noqa: F401
+from .ops import abs_  # noqa: F401
 from .ops import acos  # noqa: F401
+from .ops import acos_  # noqa: F401
 from .ops import acosh  # noqa: F401
+from .ops import acosh_  # noqa: F401
 from .ops import asin  # noqa: F401
+from .ops import asin_  # noqa: F401
 from .ops import asinh  # noqa: F401
+from .ops import asinh_  # noqa: F401
 from .ops import atan  # noqa: F401
+from .ops import atan_  # noqa: F401
 from .ops import atanh  # noqa: F401
+from .ops import atanh_  # noqa: F401
 from .ops import ceil  # noqa: F401
 from .ops import ceil_  # noqa: F401
 from .ops import cos  # noqa: F401
+from .ops import cos_  # noqa: F401
 from .ops import cosh  # noqa: F401
+from .ops import cosh_  # noqa: F401
 from .ops import erf  # noqa: F401
+from .ops import erf_  # noqa: F401
 from .ops import exp  # noqa: F401
 from .ops import exp_  # noqa: F401
 from .ops import expm1  # noqa: F401
+from .ops import expm1_  # noqa: F401
 from .ops import floor  # noqa: F401
 from .ops import floor_  # noqa: F401
 from .ops import reciprocal  # noqa: F401

@@ -68,11 +79,15 @@ from .ops import rsqrt_  # noqa: F401
 from .ops import sigmoid  # noqa: F401
 from .ops import sigmoid_  # noqa: F401
 from .ops import sin  # noqa: F401
+from .ops import sin_  # noqa: F401
 from .ops import sinh  # noqa: F401
+from .ops import sinh_  # noqa: F401
 from .ops import sqrt  # noqa: F401
 from .ops import sqrt_  # noqa: F401
 from .ops import square  # noqa: F401
+from .ops import square_  # noqa: F401
 from .ops import tan  # noqa: F401
+from .ops import tan_  # noqa: F401
 
 __all__ = []

@@ -482,12 +497,8 @@ def pow_(x, y, name=None):
     """
     if isinstance(y, (int, float)):
         return _C_ops.pow_(x, y)
-    elif isinstance(y, (paddle.Tensor, Variable)):
-        return _C_ops.elementwise_pow_(x, y)
     else:
-        raise TypeError(
-            'y must be scalar or tensor type, but received: %s ' % (type(y))
-        )
+        raise TypeError('y must be scalar type, but received: %s ' % (type(y)))
 
 
 OP_NAMEMAPPING = {
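With the tensor branch removed here (and the matching `inplace` entry for `elementwise_pow` dropped from the yaml above), `pow_` now accepts only scalar exponents. A quick dygraph sketch of the resulting behavior:

    import paddle

    x = paddle.to_tensor([1.0, 2.0, 3.0])
    paddle.pow_(x, 2)                          # scalar exponent; x becomes [1., 4., 9.]

    try:
        paddle.pow_(x, paddle.to_tensor(2.0))  # tensor exponents are now rejected
    except TypeError as err:
        print(err)                             # y must be scalar type, but received: ...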
@@ -2043,6 +2054,66 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
     return out
 
 
+@inplace_apis_in_dygraph_only
+def addmm_(input, x, y, beta=1.0, alpha=1.0, name=None):
+    """
+    Inplace version of ``addmm`` API, the output Tensor will be inplaced with input ``input``.
+    Please refer to :ref:`api_label_addmm`.
+    """
+    input_shape = input.shape
+    x_shape = x.shape
+    y_shape = y.shape
+    if not len(x_shape) == len(y_shape) == 2:
+        raise ValueError(
+            "The dimension of x, y should be 2 but received x's shape: {}, y's shape: {}".format(
+                x_shape, y_shape
+            )
+        )
+    if x_shape[1] != y_shape[0]:
+        raise ValueError(
+            "The input Variable x's width must be equal to Variable y's height. But received x's shape = {}, y's shape = {}.".format(
+                x_shape, y_shape
+            )
+        )
+    if len(input_shape) == 2:
+        if input_shape[0] != x_shape[0]:
+            if input_shape[0] != 1:
+                raise ValueError(
+                    "When x's dimension[0] is not equal to input's dimension[0], input's dimension[0] must be 1 but got {}".format(
+                        input_shape[0]
+                    )
+                )
+        if input_shape[1] != y_shape[1]:
+            if input_shape[1] != 1:
+                raise ValueError(
+                    "When y's dimension[1] is not equal to input's dimension[1], input's dimension[1] must be 1 but got {}".format(
+                        input_shape[1]
+                    )
+                )
+    elif len(input_shape) == 1:
+        if input_shape[0] not in (y_shape[1], 1):
+            raise ValueError(
+                "The input's shape: {} is not broadcastable with [x.shape[0], y.shape[1]]: [{},{}]".format(
+                    input_shape, x_shape[0], y_shape[1]
+                )
+            )
+    else:
+        raise ValueError(
+            "The dimension of input should be 2 or 1 but received input's shape: {}".format(
+                input_shape
+            )
+        )
+    if in_dynamic_mode():
+        return _C_ops.addmm_(input, x, y, beta, alpha)
+
+
 def renorm(x, p, axis, max_norm):
     """
     **renorm**
......
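A short dygraph sketch of the new `addmm_`, with shapes chosen to satisfy the checks above; `input` is updated in place and also returned:

    import paddle

    inp = paddle.ones([2, 2])
    x = paddle.ones([2, 3])
    y = paddle.ones([3, 2])

    out = paddle.addmm_(inp, x, y, beta=1.0, alpha=1.0)
    print(inp.numpy())   # all 4.0: 1.0 * inp + 1.0 * (x @ y), written back into inp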
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
+
 from .. import _C_ops
 from ..fluid.data_feeder import check_variable_and_dtype
 from ..framework import LayerHelper, in_dynamic_mode

@@ -47,6 +50,21 @@ __inplace_unary_func__ = [
     'round_',
     'reciprocal_',
     'sigmoid_',
+    'abs_',
+    'sin_',
+    'sinh_',
+    'asin_',
+    'asinh_',
+    'cos_',
+    'cosh_',
+    'acos_',
+    'acosh_',
+    'tan_',
+    'atan_',
+    'atanh_',
+    'expm1_',
+    'erf_',
+    'square_',
 ]
 
 __all__ = []

@@ -76,7 +94,9 @@ for _OP in set(__inplace_unary_func__):
     _new_OP = _OP
     if _OP in __deprecated_func_name__:
         _new_OP = __deprecated_func_name__[_OP]
-    _func = generate_inplace_fn(_OP)
+    func = generate_inplace_fn(_OP)
+    func.__module__ = __name__
+    _func = inplace_apis_in_dygraph_only(func)
     globals()[_OP] = _func
     add_sample_code(
......
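Each name in `__inplace_unary_func__` is materialized by `generate_inplace_fn` and wrapped so it only performs the in-place computation in dygraph. For example, the generated `square_` behaves like this sketch (values illustrative):

    import paddle

    x = paddle.to_tensor([4.0, 9.0])
    y = paddle.square_(x)   # generated in-place variant; returns the mutated tensor
    print(x.numpy())        # [16. 81.]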
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
 import unittest
 
 import numpy as np

@@ -123,6 +124,14 @@ class TestDygraphInplace(unittest.TestCase):
         inplace_var[0] = 2.0
         np.testing.assert_array_equal(var.numpy(), inplace_var.numpy())
 
+    def test_forward_result(self):
+        var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+        no_inplace_var = self.non_inplace_api_processing(var)
+        inplace_var = self.inplace_api_processing(var)
+        np.testing.assert_array_equal(
+            no_inplace_var.numpy(), inplace_var.numpy()
+        )
+
     def test_forward_version(self):
         with paddle.fluid.dygraph.guard():
             var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)

@@ -241,6 +250,52 @@ class TestDygraphInplace(unittest.TestCase):
             np.testing.assert_array_equal(grad_var_a_inplace, grad_var_a)
 
 
+class TestDygraphInplaceWithContinuous(TestDygraphInplace):
+    def init_data(self):
+        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
+        self.dtype = "float32"
+
+    def set_np_compare_func(self):
+        np_array_equal_with_nan = functools.partial(
+            np.array_equal, equal_nan=True
+        )
+        self.np_compare = np_array_equal_with_nan
+
+    def non_inplace_api_processing(self, var):
+        return paddle.sin(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.sin_(var)
+
+    def test_continuous_inplace_backward(self):
+        # APIs whose gradient depends only on the input copy the input before
+        # the inplace calculation, so chained inplace backward is supported here.
+        grad_var_a, grad_var_a_inplace = 0, 1
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+            var_a.stop_gradient = False
+            var_b = var_a**2
+            var_c = self.inplace_api_processing(var_b)
+            var_d = self.inplace_api_processing(var_c)
+            loss = var_d.sum()
+            loss.backward()
+            grad_var_a_inplace = var_a.grad.numpy()
+
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+            var_a.stop_gradient = False
+            var_b = var_a**2
+            var_c = self.non_inplace_api_processing(var_b)
+            var_d = self.non_inplace_api_processing(var_c)
+            loss = var_d.sum()
+            loss.backward()
+            grad_var_a = var_a.grad.numpy()
+
+        self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
+
+
 class TestDygraphInplaceUnsqueeze(TestDygraphInplace):
     def non_inplace_api_processing(self, var):
         return paddle.unsqueeze(var, -1)

@@ -506,5 +561,141 @@ class TestGetitemBeforeInplace(unittest.TestCase):
         loss.backward()
 
 
+class TestDygraphInplaceAsin(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.asin(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.asin_(var)
+
+
+class TestDygraphInplaceSinh(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.sinh(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.sinh_(var)
+
+
+class TestDygraphInplaceAsinh(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.asinh(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.asinh_(var)
+
+
+class TestDygraphInplaceAbs(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.abs(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.abs_(var)
+
+
+class TestDygraphInplaceCos(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.cos(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.cos_(var)
+
+
+class TestDygraphInplaceCosh(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.cosh(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.cosh_(var)
+
+
+class TestDygraphInplaceAcos(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.acos(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.acos_(var)
+
+
+class TestDygraphInplaceAcosh(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.acosh(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.acosh_(var)
+
+
+class TestDygraphInplaceTan(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.tan(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.tan_(var)
+
+
+class TestDygraphInplaceATan(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.atan(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.atan_(var)
+
+
+class TestDygraphInplaceATanh(TestDygraphInplaceWithContinuous):
+    def non_inplace_api_processing(self, var):
+        return paddle.atanh(var)
+
+    def inplace_api_processing(self, var):
+        return paddle.atanh_(var)
+
+
+class TestDygraphInplaceAddMM(TestDygraphInplaceWithContinuous):
+    def init_data(self):
+        self.input_var_numpy = np.random.uniform(-5, 5, [10, 10])
+        self.dtype = "float32"
+        self.x = paddle.randn([10, 10], dtype="float32")
+        self.y = paddle.randn([10, 10], dtype="float32")
+
+    def non_inplace_api_processing(self, var):
+        return paddle.addmm(var, x=self.x, y=self.y)
+
+    def inplace_api_processing(self, var):
+        return paddle.addmm_(var, x=self.x, y=self.y)
+
+    def test_errors(self):
+        var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+        x1 = paddle.randn([10])
+        self.assertRaises(ValueError, paddle.addmm_, var, x1, self.y)
+
+        y1 = paddle.randn([12, 10])
+        self.assertRaises(ValueError, paddle.addmm_, var, self.x, y1)
+
+        x2 = paddle.randn([12, 10])
+        self.assertRaises(ValueError, paddle.addmm_, var, x2, self.y)
+
+        var1 = paddle.randn([1, 5])
+        self.assertRaises(ValueError, paddle.addmm_, var1, x2, self.y)
+
+        y2 = paddle.randn([10, 12])
+        self.assertRaises(ValueError, paddle.addmm_, var, self.x, y2)
+
+        var2 = paddle.randn([6])
+        self.assertRaises(ValueError, paddle.addmm_, var2, self.x, self.y)
+
+        var3 = paddle.randn([2, 3, 4])
+        self.assertRaises(ValueError, paddle.addmm_, var3, self.x, self.y)
+
+
+class TestDygraphInplacePowerScalar(TestDygraphInplaceWithContinuous):
+    def inplace_api_processing(self, var):
+        return paddle.pow_(var, 2)
+
+    def non_inplace_api_processing(self, var):
+        return paddle.pow(var, 2)
+
+    def test_type_error(self):
+        var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype)
+        with self.assertRaisesRegex(
+            TypeError,
+            'y must be scalar type, but received: %s ' % (type([2])),
+        ):
+            paddle.pow_(var, [2])
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -15,7 +15,6 @@
 import unittest
 
 import numpy as np
-from test_inplace import TestDygraphInplace
 
 import paddle
 from paddle.fluid import core

@@ -214,40 +213,5 @@ class TestPowerError(unittest.TestCase):
             self.assertRaises(TypeError, paddle.pow, x, str(y))
 
 
-class TestInplacePowerScalar(TestDygraphInplace):
-    def set_np_compare_func(self):
-        self.np_compare = np.allclose
-
-    def inplace_api_processing(self, var):
-        return paddle.pow_(var, 2)
-
-    def non_inplace_api_processing(self, var):
-        return paddle.pow(var, 2)
-
-
-class TestInplacePowerTensor(TestDygraphInplace):
-    def init_data(self):
-        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
-        self.dtype = "float32"
-        self.y = paddle.ones([10, 20, 1], dtype="float32") * 2
-
-    def set_np_compare_func(self):
-        self.np_compare = np.allclose
-
-    def inplace_api_processing(self, var):
-        return paddle.pow_(var, self.y)
-
-    def non_inplace_api_processing(self, var):
-        return paddle.pow(var, self.y)
-
-    def test_type_error(self):
-        var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype)
-        with self.assertRaisesRegex(
-            TypeError,
-            'y must be scalar or tensor type, but received: %s ' % (type([2])),
-        ):
-            paddle.pow_(var, [2])
-
-
 if __name__ == '__main__':
     unittest.main()