diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index eee6cf56407741bf92b2cc716012fefafd8e55cb..84da69ed5da027b3ba10e4702b061fdf4bc6d2c6 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -1878,12 +1878,17 @@ struct CudaCosGradFunctor : public BaseActivationFunctor { template struct CudaExpFunctor : public BaseActivationFunctor { - using MPType = typename phi::dtype::MPTypeTrait::Type; + // exp(x) = expf(x) + __device__ __forceinline__ T operator()(const T x) const { + return static_cast(expf(static_cast(x))); + } +}; +template <> +struct CudaExpFunctor : public BaseActivationFunctor { // exp(x) = exp(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(exp(x)); + __device__ __forceinline__ double operator()(const double x) const { + return exp(x); } };