Unverified commit 36b7368d, authored by xiaoting, committed by GitHub

Add arc hyperbolic function op (#37076)

* add activation

* update activation_op

* add unit tests for activation

* fix acosh for init, test=develop
Parent e64f0997
...@@ -284,6 +284,27 @@ $$out = cosh(x)$$
)DOC";
UNUSED constexpr char AsinhDoc[] = R"DOC(
Asinh Activation Operator.
$$out = asinh(x)$$
)DOC";
UNUSED constexpr char AcoshDoc[] = R"DOC(
Acosh Activation Operator.
$$out = acosh(x)$$
)DOC";
UNUSED constexpr char AtanhDoc[] = R"DOC(
Atanh Activation Operator.
$$out = atanh(x)$$
)DOC";
UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.
...@@ -832,6 +853,9 @@ REGISTER_ACTIVATION_OP_MAKER(Tan, TanDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Acosh, AcoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Asinh, AsinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Atanh, AtanhDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
...
...@@ -473,6 +473,85 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct CudaAcoshFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// Acosh(x) = acosh(x)
__device__ __forceinline__ T operator()(const T& arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(acosh(x));
}
};
template <typename T>
struct CudaAcoshGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1 / sqrt(x^2 - 1)
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / sqrt(x * x - one));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
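For reference, the gradient formula used above follows from the standard log form of acosh (a textbook identity, not taken from this patch):

$$\frac{d}{dx}\,\mathrm{acosh}(x) = \frac{d}{dx}\,\ln\!\left(x + \sqrt{x^2 - 1}\right) = \frac{1 + \frac{x}{\sqrt{x^2 - 1}}}{x + \sqrt{x^2 - 1}} = \frac{1}{\sqrt{x^2 - 1}}, \quad x > 1$$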
template <typename T>
struct CudaAsinhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// Asinh(x) = asinh(x)
__device__ __forceinline__ T operator()(const T& arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(asinh(x));
}
};
template <typename T>
struct CudaAsinhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1/sqrt(x^2 + 1)
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / sqrt(x * x + one));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
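Analogously, the asinh gradient used above is the standard derivative of its log form (stated here only for context):

$$\frac{d}{dx}\,\mathrm{asinh}(x) = \frac{d}{dx}\,\ln\!\left(x + \sqrt{x^2 + 1}\right) = \frac{1}{\sqrt{x^2 + 1}}$$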
template <typename T>
struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// Atanh(x) = atanh(x)
__device__ __forceinline__ T operator()(const T& arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(atanh(x));
}
};
template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1/(1- x^2)
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / (one - x * x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
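And the atanh gradient is the textbook derivative of its logarithmic form:

$$\frac{d}{dx}\,\mathrm{atanh}(x) = \frac{d}{dx}\,\frac{1}{2}\ln\frac{1 + x}{1 - x} = \frac{1}{1 - x^2}, \quad |x| < 1$$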
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
...@@ -1707,6 +1786,9 @@ REGISTER_OP_CUDA_KERNEL(
__macro(asin, Asin, CudaAsinFunctor, CudaAsinGradFunctor); \
__macro(sinh, Sinh, CudaSinhFunctor, CudaSinhGradFunctor); \
__macro(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor); \
__macro(asinh, Asinh, CudaAsinhFunctor, CudaAsinhGradFunctor); \
__macro(acosh, Acosh, CudaAcoshFunctor, CudaAcoshGradFunctor); \
__macro(atanh, Atanh, CudaAtanhFunctor, CudaAtanhGradFunctor); \
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \
...
...@@ -1020,6 +1020,107 @@ struct AtanGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Acosh {
HOSTDEVICE T operator()(const T& val) const { return acosh(val); }
};
template <>
struct Acosh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(acosh(static_cast<float>(val)));
}
};
// Acosh(x) = acosh(x)
template <typename T>
struct AcoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Acosh<T>());
}
};
// acosh'(x) = 1/sqrt(x^2 - 1)
template <typename T>
struct AcoshGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (x * x - static_cast<T>(1)).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Asinh {
HOSTDEVICE T operator()(const T& val) const { return asinh(val); }
};
template <>
struct Asinh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(asinh(static_cast<float>(val)));
}
};
// Asinh(x) = asinh(x)
template <typename T>
struct AsinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Asinh<T>());
}
};
// asinh'(x) = 1/sqrt(x^2 + 1)
template <typename T>
struct AsinhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (x.square() + static_cast<T>(1)).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Atanh {
HOSTDEVICE T operator()(const T& val) const { return atanh(val); }
};
template <>
struct Atanh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(atanh(static_cast<float>(val)));
}
};
// Atanh(x) = atanh(x)
template <typename T>
struct AtanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Atanh<T>());
}
};
// atanh'(x) = 1/(1 - x^2)
template <typename T>
struct AtanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) - x.square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
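As a quick sanity check on the gradient formulas implemented by the functors above, here is a small NumPy finite-difference sketch (illustrative only; it is not part of the patch or of the official test suite):

import numpy as np

def numeric_grad(f, x, eps=1e-6):
    # central finite difference, applied elementwise
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x_acosh = np.random.uniform(2, 3, 100)       # acosh is defined for x >= 1
x_asinh = np.random.uniform(-3, 3, 100)      # asinh is defined everywhere
x_atanh = np.random.uniform(-0.9, 0.9, 100)  # atanh is defined for |x| < 1

assert np.allclose(numeric_grad(np.arccosh, x_acosh),
                   1 / np.sqrt(x_acosh ** 2 - 1), atol=1e-4)
assert np.allclose(numeric_grad(np.arcsinh, x_asinh),
                   1 / np.sqrt(x_asinh ** 2 + 1), atol=1e-4)
assert np.allclose(numeric_grad(np.arctanh, x_atanh),
                   1 / (1 - x_atanh ** 2), atol=1e-4)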
// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
...@@ -2719,6 +2820,9 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(asin, Asin, AsinFunctor, AsinGradFunctor); \
__macro(sinh, Sinh, SinhFunctor, SinhGradFunctor); \
__macro(cosh, Cosh, CoshFunctor, CoshGradFunctor); \
__macro(asinh, Asinh, AsinhFunctor, AsinhGradFunctor); \
__macro(acosh, Acosh, AcoshFunctor, AcoshGradFunctor); \
__macro(atanh, Atanh, AtanhFunctor, AtanhGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \
...
...@@ -228,6 +228,9 @@ from .tensor.math import trunc # noqa: F401
from .tensor.math import digamma # noqa: F401
from .tensor.math import neg # noqa: F401
from .tensor.math import lgamma # noqa: F401
from .tensor.math import acosh # noqa: F401
from .tensor.math import asinh # noqa: F401
from .tensor.math import atanh # noqa: F401
from .tensor.math import lerp # noqa: F401
from .tensor.math import rad2deg # noqa: F401
from .tensor.math import deg2rad # noqa: F401
...@@ -566,6 +569,9 @@ __all__ = [ # noqa
'einsum',
'set_flags',
'get_flags',
'asinh',
'acosh',
'atanh',
'as_complex',
'as_real',
'diff',
...
...@@ -55,6 +55,9 @@ __unary_func__ = [
'reciprocal',
'square',
'lgamma',
'acosh',
'asinh',
'atanh',
]
__inplace_unary_func__ = [
...@@ -372,6 +375,45 @@ Examples:
""")
add_sample_code(globals()["asinh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.asinh(x)
print(out)
# [-0.39003533, -0.19869010, 0.09983408, 0.29567307]
""")
add_sample_code(globals()["acosh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1., 3., 4., 5.])
out = paddle.acosh(x)
print(out)
# [0. , 1.76274729, 2.06343699, 2.29243159]
""")
add_sample_code(globals()["atanh"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.atanh(x)
print(out)
# [-0.42364895, -0.20273256, 0.10033535, 0.30951962]
""")
add_sample_code(globals()["round"], r""" add_sample_code(globals()["round"], r"""
Examples: Examples:
.. code-block:: python .. code-block:: python
......
...@@ -1145,6 +1145,60 @@ class TestAsin(TestActivation): ...@@ -1145,6 +1145,60 @@ class TestAsin(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestAcosh(TestActivation):
def setUp(self):
self.op_type = "acosh"
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(2, 3, [10, 12]).astype(self.dtype)
out = np.arccosh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestAsinh(TestActivation):
def setUp(self):
self.op_type = "asinh"
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(1, 2, [10, 12]).astype(self.dtype)
out = np.arcsinh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestAtanh(TestActivation):
def setUp(self):
self.op_type = "atanh"
self.init_dtype()
np.random.seed(400)
x = np.random.uniform(-0.9, 0.9, [10, 12]).astype(self.dtype)
out = np.arctanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestRound(TestActivation):
def setUp(self):
self.op_type = "round"
...@@ -2815,6 +2869,9 @@ create_test_error_class('sin')
create_test_error_class('sqrt')
create_test_error_class('tanh')
create_test_error_class('tan')
create_test_error_class('acosh')
create_test_error_class('asinh')
create_test_error_class('atanh')
#------------------ Test Cudnn Activation----------------------
...@@ -2886,6 +2943,9 @@ create_test_act_fp16_class(TestSin)
create_test_act_fp16_class(TestSinh)
create_test_act_fp16_class(TestAsin)
create_test_act_fp16_class(TestAtan)
create_test_act_fp16_class(TestAcosh, grad_atol=0.85)
create_test_act_fp16_class(TestAsinh, grad_atol=0.85)
create_test_act_fp16_class(TestAtanh, grad_atol=0.85)
create_test_act_fp16_class(TestRound, grad_check=False)
create_test_act_fp16_class(TestRelu)
create_test_act_fp16_class(TestGelu)
...
...@@ -371,6 +371,15 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
np.array_equal(x.rank().numpy(), paddle.rank(x).numpy()))
self.assertTrue(
np.array_equal(x[0].t().numpy(), paddle.t(x[0]).numpy()))
self.assertTrue(
np.array_equal(x.asinh().numpy(), paddle.asinh(x).numpy()))
### acosh(x) is nan for x < 1, so use a separate input drawn from [1, 2)
t_np = np.random.uniform(1, 2, [2, 3]).astype(self.dtype)
t = paddle.to_tensor(t_np)
self.assertTrue(
np.array_equal(t.acosh().numpy(), paddle.acosh(t).numpy()))
self.assertTrue(
np.array_equal(x.atanh().numpy(), paddle.atanh(x).numpy()))
d = paddle.to_tensor([[1.2285208, 1.3491015, 1.4899898],
[1.30058, 1.0688717, 1.4928783],
[1.0958099, 1.3724753, 1.8926544]])
...
...@@ -194,6 +194,9 @@ from .math import digamma # noqa: F401
from .math import neg # noqa: F401
from .math import lgamma # noqa: F401
from .math import diagonal # noqa: F401
from .math import acosh # noqa: F401
from .math import asinh # noqa: F401
from .math import atanh # noqa: F401
from .math import lerp # noqa: F401
from .math import lerp_ # noqa: F401
from .math import rad2deg # noqa: F401
...@@ -420,6 +423,9 @@ tensor_method_func = [ #noqa
'multi_dot',
'solve',
'triangular_solve',
'asinh',
'atanh',
'acosh',
'as_complex',
'as_real',
'rad2deg',
...
...@@ -65,6 +65,9 @@ from ..fluid.layers import sqrt # noqa: F401
from ..fluid.layers import sqrt_ # noqa: F401
from ..fluid.layers import sin # noqa: F401
from ..fluid.layers import lgamma # noqa: F401
from ..fluid.layers import asinh # noqa: F401
from ..fluid.layers import acosh # noqa: F401
from ..fluid.layers import atanh # noqa: F401
from ..fluid.layers import multiplex # noqa: F401
from ..fluid import layers
...