From c3a61349e4fd0dd98fe8fbe80d2553dffe5626a0 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Tue, 14 Nov 2017 22:18:31 +0530 Subject: [PATCH] Adding greater than and less than equal ops to compare op (#5609) * Adding greater than and less than equal ops to compare op * Changing the name of the less_than_equal and greater_than_equal op * Also changing the name of the functors --- paddle/operators/compare_op.cc | 8 ++++++++ paddle/operators/compare_op.cu | 5 +++++ paddle/operators/compare_op.h | 18 ++++++++++++++++++ .../paddle/v2/fluid/tests/test_compare_op.py | 3 +++ 4 files changed, 34 insertions(+) diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index 716b5ee92d..bf7e883681 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -94,5 +94,13 @@ class CompareOp : public framework::OperatorWithKernel { REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); +REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); +REGISTER_LOGICAL_OP(greater_than, "Out = X > Y"); +REGISTER_LOGICAL_KERNEL(greater_than, CPU, + paddle::operators::GreaterThanFunctor); +REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y"); +REGISTER_LOGICAL_KERNEL(greater_equal, CPU, + paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.cu b/paddle/operators/compare_op.cu index 42a5bb2f45..6ac8c124b9 100644 --- a/paddle/operators/compare_op.cu +++ b/paddle/operators/compare_op.cu @@ -15,4 +15,9 @@ #include "paddle/operators/compare_op.h" REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor); +REGISTER_LOGICAL_KERNEL(greater_than, GPU, + paddle::operators::GreaterThanFunctor); +REGISTER_LOGICAL_KERNEL(greater_equal, GPU, + paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h index 04e04e347b..afdf3ab3e0 100644 --- a/paddle/operators/compare_op.h +++ b/paddle/operators/compare_op.h @@ -27,6 +27,24 @@ struct LessThanFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } }; +template +struct LessEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } +}; + +template +struct GreaterThanFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } +}; + +template +struct GreaterEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } +}; + template struct EqualFunctor { using ELEM_TYPE = T; diff --git a/python/paddle/v2/fluid/tests/test_compare_op.py b/python/paddle/v2/fluid/tests/test_compare_op.py index bb0256694d..5d0dfab6ff 100644 --- a/python/paddle/v2/fluid/tests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/test_compare_op.py @@ -23,6 +23,9 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) + create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) if __name__ == '__main__': -- GitLab