From d4e3495cf5978fff9eb80b944c948598cdf48d3a Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 24 Feb 2018 13:07:46 +0800 Subject: [PATCH] add larger_than and larger_equal op and kernel --- paddle/fluid/operators/compare_op.cc | 5 +++++ paddle/fluid/operators/compare_op.cu | 4 ++++ paddle/fluid/operators/compare_op.h | 12 ++++++++++++ 3 files changed, 21 insertions(+) diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index 46d6d0fd8..a5f40d548 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -100,6 +100,11 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); +REGISTER_COMPARE_OP(larger_than, "Out = X > Y"); +REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor); +REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y"); +REGISTER_COMPARE_KERNEL(larger_equal, CPU, + paddle::operators::LargerEqualFunctor); REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index c6c83b44b..32d92a906 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -16,5 +16,9 @@ limitations under the License. */ REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); +REGISTER_COMPARE_KERNEL(larger_than, CUDA, + paddle::operators::LargerThanFunctor); +REGISTER_COMPARE_KERNEL(larger_equal, CUDA, + paddle::operators::LargerEqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index 6638e5ae9..b4546f27b 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -34,6 +34,18 @@ struct LessEqualFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } }; +template +struct LargerThanFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } +}; + +template +struct LargerEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } +}; + template struct EqualFunctor { using ELEM_TYPE = T; -- GitLab