diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index 46d6d0fd862912b5b4fb5aaaedf4526de2dcdd98..a5f40d54829b3feca93667e53efdfed7619633e0 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -100,6 +100,11 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); +REGISTER_COMPARE_OP(larger_than, "Out = X > Y"); +REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor); +REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y"); +REGISTER_COMPARE_KERNEL(larger_equal, CPU, + paddle::operators::LargerEqualFunctor); REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index c6c83b44b8ab6f36205c852e34aaf31a940e1888..32d92a9066677c1e526b7d8ce604b40103bd1ed4 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -16,5 +16,9 @@ limitations under the License. */ REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); +REGISTER_COMPARE_KERNEL(larger_than, CUDA, + paddle::operators::LargerThanFunctor); +REGISTER_COMPARE_KERNEL(larger_equal, CUDA, + paddle::operators::LargerEqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index 6638e5ae9e6ed28b26329d97fb06716e3a39b276..b4546f27b1ef831310f3616df6c1e403d493a577 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -34,6 +34,18 @@ struct LessEqualFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } }; +template +struct LargerThanFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } +}; + +template +struct LargerEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } +}; + template struct EqualFunctor { using ELEM_TYPE = T;