diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index 716b5ee92d0d8737d2069460f53989f691ff7c77..bf7e88368157d29e627c3c06384f28b6e5e4ecc1 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -94,5 +94,13 @@ class CompareOp : public framework::OperatorWithKernel { REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); +REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); +REGISTER_LOGICAL_OP(greater_than, "Out = X > Y"); +REGISTER_LOGICAL_KERNEL(greater_than, CPU, + paddle::operators::GreaterThanFunctor); +REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y"); +REGISTER_LOGICAL_KERNEL(greater_equal, CPU, + paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.cu b/paddle/operators/compare_op.cu index 42a5bb2f45fd389f60c3dc034cade7f56a907e35..6ac8c124b9b2e7c808808ecc8802a2e5aeaa5b5d 100644 --- a/paddle/operators/compare_op.cu +++ b/paddle/operators/compare_op.cu @@ -15,4 +15,9 @@ #include "paddle/operators/compare_op.h" REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor); +REGISTER_LOGICAL_KERNEL(greater_than, GPU, + paddle::operators::GreaterThanFunctor); +REGISTER_LOGICAL_KERNEL(greater_equal, GPU, + paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h index 04e04e347b398abb5fb66876bf801b1eee688ec6..afdf3ab3e098b4e7f4c996471617d97ec49264b1 100644 --- a/paddle/operators/compare_op.h +++ b/paddle/operators/compare_op.h @@ -27,6 +27,24 @@ struct LessThanFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } }; +template +struct LessEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } +}; + +template +struct GreaterThanFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } +}; + +template +struct GreaterEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } +}; + template struct EqualFunctor { using ELEM_TYPE = T; diff --git a/python/paddle/v2/fluid/tests/test_compare_op.py b/python/paddle/v2/fluid/tests/test_compare_op.py index bb0256694d77323f12c50856533e93b090dc6198..5d0dfab6ffd1cbbbfbcdb3af60f1868b7b780456 100644 --- a/python/paddle/v2/fluid/tests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/test_compare_op.py @@ -23,6 +23,9 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) + create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) if __name__ == '__main__':