diff --git a/paddle/fluid/operators/controlflow/bitwise_op.h b/paddle/fluid/operators/controlflow/bitwise_op.h index 92abe4cd3b1c3630ed9c2652f2ff8a49f033f13b..9e652f92007479684fcf8ec5e539312d8d729107 100644 --- a/paddle/fluid/operators/controlflow/bitwise_op.h +++ b/paddle/fluid/operators/controlflow/bitwise_op.h @@ -22,19 +22,19 @@ limitations under the License. */ namespace paddle { namespace operators { -#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ - template \ - struct Bitwise##func##Functor { \ - using ELEM_TYPE = T; \ - HOSTDEVICE T operator()(const T& a, const T& b) const { return a expr b; } \ - }; \ - \ - template <> \ - struct Bitwise##func##Functor { \ - using ELEM_TYPE = bool; \ - HOSTDEVICE bool operator()(const bool& a, const bool& b) const { \ - return a bool_expr b; \ - } \ +#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ + template \ + struct Bitwise##func##Functor { \ + using ELEM_TYPE = T; \ + HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \ + }; \ + \ + template <> \ + struct Bitwise##func##Functor { \ + using ELEM_TYPE = bool; \ + HOSTDEVICE bool operator()(const bool a, const bool b) const { \ + return a bool_expr b; \ + } \ }; BITWISE_BINARY_FUNCTOR(And, &, &&) @@ -45,13 +45,13 @@ BITWISE_BINARY_FUNCTOR(Xor, ^, !=) template struct BitwiseNotFunctor { using ELEM_TYPE = T; - HOSTDEVICE T operator()(const T& a) const { return ~a; } + HOSTDEVICE T operator()(const T a) const { return ~a; } }; template <> struct BitwiseNotFunctor { using ELEM_TYPE = bool; - HOSTDEVICE bool operator()(const bool& a) const { return !a; } + HOSTDEVICE bool operator()(const bool a) const { return !a; } }; template diff --git a/paddle/fluid/operators/controlflow/compare_all_op.h b/paddle/fluid/operators/controlflow/compare_all_op.h index bcad240601cf6778c8eeaabc50f262b3ee5e938d..78a7b76e3fd9d03f2381dfb13f90c191d1dca4f8 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.h +++ b/paddle/fluid/operators/controlflow/compare_all_op.h @@ -28,7 +28,7 @@ namespace operators { template struct EqualReduceFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { + HOSTDEVICE bool operator()(const T a, const T b) const { if (std::is_floating_point::value) { // This branch will be optimized while compiling if T is integer. It is // safe to cast a and b to double. diff --git a/paddle/fluid/operators/controlflow/compare_op.h b/paddle/fluid/operators/controlflow/compare_op.h index 36185322a96b8909c49e1a3c5a55afa47d4952bc..d2ef4c9befba99290008508e43df6c84f969b710 100644 --- a/paddle/fluid/operators/controlflow/compare_op.h +++ b/paddle/fluid/operators/controlflow/compare_op.h @@ -25,31 +25,31 @@ namespace operators { template struct LessThanFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; } }; template struct LessEqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a <= b; } }; template struct GreaterThanFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; } }; template struct GreaterEqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } + HOSTDEVICE bool operator()(const T a, const T b) const { return a >= b; } }; template struct EqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { + HOSTDEVICE bool operator()(const T a, const T b) const { if (std::is_floating_point::value) { // This branch will be optimized while compiling if T is integer. It is // safe to cast a and b to double. @@ -63,7 +63,7 @@ struct EqualFunctor { template struct NotEqualFunctor { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { + HOSTDEVICE bool operator()(const T a, const T b) const { return !EqualFunctor()(a, b); } }; diff --git a/paddle/fluid/operators/controlflow/logical_op.cu b/paddle/fluid/operators/controlflow/logical_op.cu index 301b4c4149fad3be6ae121b7985e61dd42ef7c36..4a3fc6c895174c088fc98a017515c58101cd4d70 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cu +++ b/paddle/fluid/operators/controlflow/logical_op.cu @@ -18,26 +18,6 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -#define LOGICAL_BINARY_FUNCTOR(func_name, op) \ - template \ - struct func_name { \ - using ELEMENT_TYPE = T; \ - HOSTDEVICE bool operator()(const T* args) const { \ - return static_cast(args[0]) op static_cast(args[1]); \ - } \ - }; - -LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||) -LOGICAL_BINARY_FUNCTOR(CudaAndFunctor, &&) -LOGICAL_BINARY_FUNCTOR(CudaXorFunctor, ^) -#undef LOGICAL_BINARY_FUNCTOR - -template -struct CudaNotFunctor { - using ELEMENT_TYPE = T; - HOSTDEVICE bool operator()(const T* args) const { return !args[0]; } -}; - template class BinaryLogicalOpKernel : public framework::OpKernel { @@ -76,8 +56,8 @@ class BinaryLogicalOpKernel ops::BinaryLogicalOpKernel>, \ ops::BinaryLogicalOpKernel>); -REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor) -REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor) -REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, CudaXorFunctor) -REGISTER_LOGICAL_CUDA_KERNEL(logical_not, CudaNotFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_or, LogicalOrFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_and, LogicalAndFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, LogicalXorFunctor) +REGISTER_LOGICAL_CUDA_KERNEL(logical_not, LogicalNotFunctor) #undef REGISTER_LOGICAL_CUDA_KERNEL diff --git a/paddle/fluid/operators/controlflow/logical_op.h b/paddle/fluid/operators/controlflow/logical_op.h index 92fe0a10cb907c333954f51b06a199a6c23cffbe..ee63da60fcd0fea223414d10d74f84f52e9e9e45 100644 --- a/paddle/fluid/operators/controlflow/logical_op.h +++ b/paddle/fluid/operators/controlflow/logical_op.h @@ -19,38 +19,32 @@ limitations under the License. */ namespace paddle { namespace operators { -template -struct LogicalAndFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; } -}; +#define LOGICAL_BINARY_FUNCTOR(func_name, op) \ + template \ + struct func_name { \ + using ELEMENT_TYPE = T; \ + HOSTDEVICE bool operator()(const T a, const T b) const { \ + return static_cast(a) op static_cast(b); \ + } \ + }; -template -struct LogicalOrFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; } -}; +LOGICAL_BINARY_FUNCTOR(LogicalOrFunctor, ||) +LOGICAL_BINARY_FUNCTOR(LogicalAndFunctor, &&) +LOGICAL_BINARY_FUNCTOR(LogicalXorFunctor, ^) +#undef LOGICAL_BINARY_FUNCTOR template struct LogicalNotFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a) const { return !a; } -}; - -template -struct LogicalXorFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { - return (a || b) && !(a && b); - } + using ELEMENT_TYPE = T; + HOSTDEVICE bool operator()(const T a) const { return !a; } }; template class BinaryLogicalOpKernel - : public framework::OpKernel { + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; + using T = typename Functor::ELEMENT_TYPE; auto* x = context.Input("X"); auto* y = context.Input("Y"); auto* out = context.Output("Out"); @@ -62,10 +56,10 @@ class BinaryLogicalOpKernel template class UnaryLogicalOpKernel - : public framework::OpKernel { + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; + using T = typename Functor::ELEMENT_TYPE; auto* x = context.Input("X"); auto* out = context.Output("Out"); Functor unary_func;