diff --git a/paddle/phi/kernels/primitive/functor_primitives_xpu2.h b/paddle/phi/kernels/primitive/functor_primitives_xpu2.h index fdcbb5ec9cc8d0a6a678ec9979524ba0c5a5e7a8..35cea6e6927871c6d47049d0fe0523f655301723 100644 --- a/paddle/phi/kernels/primitive/functor_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/functor_primitives_xpu2.h @@ -55,21 +55,21 @@ struct DivideFunctor { inline DivideFunctor() { n_inv = static_cast(1.0f); } explicit inline DivideFunctor(int n) - : n_inv(static_cast(((float)1.0) / (static_cast(n)))) {} + : n_inv(static_cast(1.0f / (static_cast(n)))) {} inline Ty operator()(const Tx& x) const { return static_cast(x * n_inv); } __device__ inline DivideFunctor() { n_inv = static_cast(1.0f); } __device__ inline DivideFunctor(int n) - : n_inv(static_cast(((float)1.0) / (static_cast(n)))) {} + : n_inv(static_cast(1.0f / (static_cast(n)))) {} __device__ inline Ty operator()(const Tx& x) const { return static_cast(x * n_inv); } __device__ inline void SetDiv(int n) { - n_inv = static_cast(((float)1.0) / (static_cast(n))); + n_inv = static_cast(1.0f / (static_cast(n))); } private: @@ -97,8 +97,7 @@ struct SquareFunctor { */ template struct MinFunctor { - inline T initial() { /*return static_cast(std::numeric_limits::max());*/ - } + inline T initial() { return static_cast(std::numeric_limits::max()); } __device__ T operator()(const T& a, const T& b) const { return (b < a) ? b : a; @@ -111,7 +110,7 @@ struct MinFunctor { template struct MaxFunctor { inline T initial() { - // return static_cast(std::numeric_limits::lowest()); + return static_cast(std::numeric_limits::lowest()); } __device__ T operator()(const T& a, const T& b) const { @@ -124,8 +123,7 @@ struct MaxFunctor { */ template struct AddFunctor { - inline T initial() { /*return static_cast(0.0f);*/ - } + inline T initial() { return static_cast(0.0f); } __device__ T operator()(const T a, const T b) const { return b + a; } }; @@ -135,8 +133,7 @@ struct AddFunctor { */ template struct MulFunctor { - inline T initial() { /*return static_cast(1.0f);*/ - } + inline T initial() { return static_cast(1.0f); } __device__ T operator()(const T& a, const T& b) const { return b * a; } }; @@ -146,8 +143,7 @@ struct MulFunctor { */ template struct LogicalOrFunctor { - inline T initial() { /*return static_cast(false);*/ - } + inline T initial() { return static_cast(false); } __device__ T operator()(const T& a, const T& b) const { return b || a; } };