From 0efcae8676869d923eb3beca5259549e8b0776a0 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Wed, 12 Jan 2022 21:09:38 +0800 Subject: [PATCH] [part 3]change type of function args (#38887) * code clean * [part 3]change type of function args --- .../fluid/operators/controlflow/bitwise_op.h | 30 ++++++------- .../operators/controlflow/compare_all_op.h | 2 +- .../fluid/operators/controlflow/compare_op.h | 12 +++--- .../fluid/operators/controlflow/logical_op.cu | 28 ++----------- .../fluid/operators/controlflow/logical_op.h | 42 ++++++++----------- 5 files changed, 44 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/operators/controlflow/bitwise_op.h b/paddle/fluid/operators/controlflow/bitwise_op.h index 92abe4cd3b1..9e652f92007 100644 --- a/paddle/fluid/operators/controlflow/bitwise_op.h +++ b/paddle/fluid/operators/controlflow/bitwise_op.h @@ -22,19 +22,19 @@ limitations under the License. */ namespace paddle { namespace operators { -#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ - template \ - struct Bitwise##func##Functor { \ - using ELEM_TYPE = T; \ - HOSTDEVICE T operator()(const T& a, const T& b) const { return a expr b; } \ - }; \ - \ - template <> \ - struct Bitwise##func##Functor { \ - using ELEM_TYPE = bool; \ - HOSTDEVICE bool operator()(const bool& a, const bool& b) const { \ - return a bool_expr b; \ - } \ +#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ + template \ + struct Bitwise##func##Functor { \ + using ELEM_TYPE = T; \ + HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \ + }; \ + \ + template <> \ + struct Bitwise##func##Functor { \ + using ELEM_TYPE = bool; \ + HOSTDEVICE bool operator()(const bool a, const bool b) const { \ + return a bool_expr b; \ + } \ }; BITWISE_BINARY_FUNCTOR(And, &, &&) @@ -45,13 +45,13 @@ BITWISE_BINARY_FUNCTOR(Xor, ^, !=) template struct BitwiseNotFunctor { using ELEM_TYPE = T; - HOSTDEVICE T operator()(const T& a) const { return ~a; } + HOSTDEVICE T operator()(const T a) const { return ~a; } }; template <> struct BitwiseNotFunctor { using ELEM_TYPE = bool; - HOSTDEVICE bool operator()(const bool& a) const { return !a; } + HOSTDEVICE bool operator()(const bool a) const { return !a; } }; template diff --git a/paddle/fluid/operators/controlflow/compare_all_op.h b/paddle/fluid/operators/controlflow/compare_all_op.h index bcad240601c..78a7b76e3fd 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.h +++ b/paddle/fluid/operators/controlflow/compare_all_op.h @@ -28,7 +28,7 @@ namespace operators { template struct EqualReduceFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { + HOSTDEVICE bool operator()(const T a, const T b) const { if (std::is_floating_point::value) { // This branch will be optimized while compiling if T is integer. It is // safe to cast a and b to double. diff --git a/paddle/fluid/operators/controlflow/compare_op.h b/paddle/fluid/operators/controlflow/compare_op.h index 36185322a96..d2ef4c9befb 100644 --- a/paddle/fluid/operators/controlflow/compare_op.h +++ b/paddle/fluid/operators/controlflow/compare_op.h @@ -25,31 +25,31 @@ namespace operators { template struct LessThanFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; } }; template struct LessEqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a <= b; } }; template struct GreaterThanFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; } }; template struct GreaterEqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a >= b; } }; template struct EqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { + HOSTDEVICE bool operator()(const T a, const T b) const { if (std::is_floating_point::value) { // This branch will be optimized while compiling if T is integer. It is // safe to cast a and b to double. @@ -63,7 +63,7 @@ struct EqualFunctor { template struct NotEqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { + HOSTDEVICE bool operator()(const T a, const T b) const { return !EqualFunctor()(a, b); } }; diff --git a/paddle/fluid/operators/controlflow/logical_op.cu b/paddle/fluid/operators/controlflow/logical_op.cu index 301b4c4149f..4a3fc6c8951 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cu +++ b/paddle/fluid/operators/controlflow/logical_op.cu @@ -18,26 +18,6 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -#define LOGICAL_BINARY_FUNCTOR(func_name, op) \ - template \ - struct func_name { \ - using ELEMENT_TYPE = T; \ - HOSTDEVICE bool operator()(const T* args) const { \ - return static_cast(args[0]) op static_cast(args[1]); \ - } \ - }; - -LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||) -LOGICAL_BINARY_FUNCTOR(CudaAndFunctor, &&) -LOGICAL_BINARY_FUNCTOR(CudaXorFunctor, ^) -#undef LOGICAL_BINARY_FUNCTOR - -template -struct CudaNotFunctor { - using ELEMENT_TYPE = T; - HOSTDEVICE bool operator()(const T* args) const { return !args[0]; } -}; - template class BinaryLogicalOpKernel : public framework::OpKernel { @@ -76,8 +56,8 @@ class BinaryLogicalOpKernel ops::BinaryLogicalOpKernel>, \ ops::BinaryLogicalOpKernel>); -REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor) -REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor) -REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, CudaXorFunctor) -REGISTER_LOGICAL_CUDA_KERNEL(logical_not, CudaNotFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_or, LogicalOrFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_and, LogicalAndFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, LogicalXorFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_not, LogicalNotFunctor) #undef REGISTER_LOGICAL_CUDA_KERNEL diff --git a/paddle/fluid/operators/controlflow/logical_op.h b/paddle/fluid/operators/controlflow/logical_op.h index 92fe0a10cb9..ee63da60fcd 100644 --- a/paddle/fluid/operators/controlflow/logical_op.h +++ b/paddle/fluid/operators/controlflow/logical_op.h @@ -19,38 +19,32 @@ limitations under the License. */ namespace paddle { namespace operators { -template -struct LogicalAndFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; } -}; +#define LOGICAL_BINARY_FUNCTOR(func_name, op) \ + template \ + struct func_name { \ + using ELEMENT_TYPE = T; \ + HOSTDEVICE bool operator()(const T a, const T b) const { \ + return static_cast(a) op static_cast(b); \ + } \ + }; -template -struct LogicalOrFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; } -}; +LOGICAL_BINARY_FUNCTOR(LogicalOrFunctor, ||) +LOGICAL_BINARY_FUNCTOR(LogicalAndFunctor, &&) +LOGICAL_BINARY_FUNCTOR(LogicalXorFunctor, ^) +#undef LOGICAL_BINARY_FUNCTOR template struct LogicalNotFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a) const { return !a; } -}; - -template -struct LogicalXorFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { - return (a || b) && !(a && b); - } + using ELEMENT_TYPE = T; + HOSTDEVICE bool operator()(const T a) const { return !a; } }; template class BinaryLogicalOpKernel - : public framework::OpKernel { + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; + using T = typename Functor::ELEMENT_TYPE; auto* x = context.Input("X"); auto* y = context.Input("Y"); auto* out = context.Output("Out"); @@ -62,10 +56,10 @@ class BinaryLogicalOpKernel template class UnaryLogicalOpKernel - : public framework::OpKernel { + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; + using T = typename Functor::ELEMENT_TYPE; auto* x = context.Input("X"); auto* out = context.Output("Out"); Functor unary_func; -- GitLab