Unverified commit 58ae8c7c, authored by Hui Zhang, committed by GitHub

exp/expm1 support int32/int64/float16 forward (#54556)

* fix for log xxx

* add int32/int64 for cpu/gpu; add float16/bfloat16 for cpu forward

* fix docstring

* fix bug

* fix bugs

* fix bugs

* fix bugs

* fix bugs

* fix bug

* using cast

* fix test

* fix api

* fix other bugs

* fix ci bug for not using dygraph guard

* add bfloat16 test

* fix ut

* bf16

* exp/expm1 support int32/int64

* fix ut

* fix ut

* fix ut
Parent 9a36fd4b
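As a quick illustration of what this change enables (a sketch only, mirroring the new unit tests in the diff below, and assuming a Paddle build that contains this commit), `paddle.exp` and `paddle.expm1` now run forward on int32/int64/float16 inputs in dynamic graph mode:

```python
import numpy as np
import paddle

# Sketch based on the Test_Exp_Op_Int / Test_Expm1_Op_Int cases added below:
# exp/expm1 now accept int32/int64/float16 inputs in dygraph mode.
paddle.disable_static()
np_x = np.array([[2, 3, 4], [7, 8, 9]])
for dtype in ('int32', 'int64', 'float16'):
    x = paddle.to_tensor(np_x, dtype=dtype)
    np.testing.assert_allclose(
        paddle.exp(x).numpy(), np.exp(np_x.astype(dtype)), rtol=1e-3
    )
    np.testing.assert_allclose(
        paddle.expm1(x).numpy(), np.expm1(np_x.astype(dtype)), rtol=1e-3
    )
paddle.enable_static()
```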
@@ -86,8 +86,6 @@ DEFINE_CPU_ACTIVATION_KERNEL(Relu, ReluCPUFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Tanh, TanhFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Exp, ExpFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Expm1, Expm1Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Reciprocal, ReciprocalFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Square, SquareFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sqrt, SqrtFunctor)
@@ -104,6 +102,8 @@ DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, Log2Functor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, Log10Functor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, Log1pFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, ExpFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, Expm1Functor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
@@ -175,15 +175,26 @@ PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_KERNEL(
exp, CPU, ALL_LAYOUT, phi::ExpKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(exp,
CPU,
ALL_LAYOUT,
phi::ExpKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(expm1,
CPU,
ALL_LAYOUT,
phi::Expm1Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
PD_REGISTER_KERNEL(
square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
@@ -1074,9 +1074,11 @@ struct AtanhGradFunctor : public BaseActivationFunctor<T> {
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.exp();
out.device(d) = x.template cast<U>().exp();
}
};
@@ -1099,9 +1101,11 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
// expm1(x) = e^x - 1
template <typename T>
struct Expm1Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.expm1();
out.device(d) = x.template cast<U>().expm1();
}
};
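The `std::conditional_t` alias in the functors above promotes integral element types to `float` before calling `exp`/`expm1`, which is what lets integer tensors go through the same forward path. A minimal Python-level sketch of that numeric behavior (an illustration only, not part of the diff; it assumes a build containing this commit):

```python
import numpy as np
import paddle

# The CPU functors cast integral inputs to float before exponentiating,
# so the result matches exp() applied to the float-cast values.
paddle.disable_static()
x_int = paddle.to_tensor([2, 3, 4], dtype='int64')
ref = np.exp(np.array([2, 3, 4], dtype=np.float32))  # cast first, then exp
np.testing.assert_allclose(paddle.exp(x_int).numpy(), ref, rtol=1e-3)
paddle.enable_static()
```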
@@ -2668,8 +2672,10 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
// exp(x) = expf(x)
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(expf(static_cast<float>(x)));
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
__device__ __forceinline__ U operator()(const T x) const {
return static_cast<U>(expf(static_cast<float>(x)));
}
};
@@ -2781,12 +2787,19 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// expm1(x) = expm1f(x)
__device__ __forceinline__ U operator()(const T x) const {
return static_cast<U>(::expm1f(static_cast<float>(x)));
}
};
template <>
struct CudaExpm1Functor<double> : public BaseActivationFunctor<double> {
// expm1(x) = expm1(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(expm1(x));
__device__ __forceinline__ double operator()(const double x) const {
return ::expm1(x);
}
};
@@ -103,8 +103,6 @@ DEFINE_GPU_ACTIVATION_KERNEL(Relu, CudaReluFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Tanh, CudaTanhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Exp, CudaExpFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Expm1, CudaExpm1Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Reciprocal, CudaReciprocalFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Square, CudaSquareFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sqrt, CudaSqrtFunctor)
@@ -120,6 +118,8 @@ DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, CudaLog2Functor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, CudaLog10Functor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, CudaLog1pFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, CudaExpFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, CudaExpm1Functor)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
@@ -237,6 +237,8 @@ PD_REGISTER_KERNEL(expm1,
phi::Expm1Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square,
@@ -538,7 +538,7 @@ def exp(x, name=None):
out = e^x
Args:
x (Tensor): Input of Exp operator, an N-D Tensor, with data type float32, float64 or float16.
x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
@@ -570,7 +570,6 @@ def exp(x, name=None):
'float64',
'complex64',
'complex128',
'uint16',
],
'exp',
)
@@ -589,7 +588,7 @@ def expm1(x, name=None):
out = e^x - 1
Args:
x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type float32, float64 or float16.
x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
@@ -610,7 +609,10 @@ def expm1(x, name=None):
return _C_ops.expm1(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'expm1'
x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'expm1',
)
helper = LayerHelper('expm1', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -154,6 +154,35 @@ class TestExpPrim_ZeroDim(TestExpFp32_Prim):
self.shape = []
class Test_Exp_Op_Fp16(unittest.TestCase):
def test_api_fp16(self):
with paddle.fluid.framework._static_guard():
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
np_x = np.array([[2, 3, 4], [7, 8, 9]])
x = paddle.to_tensor(np_x, dtype='float16')
out = paddle.exp(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
x_expect = np.exp(np_x.astype('float16'))
np.testing.assert_allclose(res, x_expect, rtol=1e-3)
class Test_Exp_Op_Int(unittest.TestCase):
def test_api_int(self):
paddle.disable_static()
for dtype in ('int32', 'int64', 'float16'):
np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.exp(x)
x_expect = np.exp(np_x)
np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
paddle.enable_static()
class TestExpm1(TestActivation):
def setUp(self):
self.op_type = "expm1"
@@ -222,12 +251,17 @@ class TestExpm1API(unittest.TestCase):
for place in self.place:
run(place)
def test_errors(self):
with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()):
X = paddle.static.data('X', self.shape, dtype='int32')
self.assertRaises(TypeError, paddle.expm1, X)
# The input dtype must be float16, float32, float64.
class Test_Expm1_Op_Int(unittest.TestCase):
def test_api_int(self):
paddle.disable_static()
for dtype in ('int32', 'int64', 'float16'):
np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.expm1(x)
x_expect = np.expm1(np_x)
np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
paddle.enable_static()
class TestParameter: