diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index cdeb28cc1db2c3328d8f605e6584f5b5e1311e97..86f7046058c7001fcaa588727b1cdc0f3f20c35f 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -83,7 +83,7 @@ class CompareOp : public framework::OperatorWithKernel { } // namespace operators } // namespace paddle -#define REGISTER_LOGICAL_OP(op_type, _equation) \ +#define REGISTER_COMPARE_OP(op_type, _equation) \ struct _##op_type##Comment { \ static char type[]; \ static char equation[]; \ @@ -96,11 +96,17 @@ class CompareOp : public framework::OperatorWithKernel { ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ ::paddle::framework::EmptyGradOpMaker); -REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); -REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); -REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); -REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); -REGISTER_LOGICAL_OP(equal, "Out = X == Y"); -REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); -REGISTER_LOGICAL_OP(not_equal, "Out = X != Y"); -REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor); +REGISTER_COMPARE_OP(less_than, "Out = X < Y"); +REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); +REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); +REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); +REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); +REGISTER_COMPARE_KERNEL(greater_than, CPU, + paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); +REGISTER_COMPARE_KERNEL(greater_equal, CPU, + paddle::operators::GreaterEqualFunctor); +REGISTER_COMPARE_OP(equal, "Out = X == Y"); +REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); +REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); +REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index 2cc0c7c57257ad5bd89606a5e29f17156cbda773..1bf85c64fb5b4d79c62118959fd72b13ed1c63ed 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -14,7 +14,11 @@ limitations under the License. */ #include "paddle/fluid/operators/compare_op.h" -REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); -REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); -REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); -REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); +REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); +REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); +REGISTER_COMPARE_KERNEL(greater_than, CUDA, + paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_KERNEL(greater_equal, CUDA, + paddle::operators::GreaterEqualFunctor); +REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); +REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index 7e78269cf4767eb78cde1decaedc96a3c921716a..1cbabdaf6767815c1fedba0eabec9b5de678e047 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -34,6 +34,18 @@ struct LessEqualFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } }; +template +struct GreaterThanFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } +}; + +template +struct GreaterEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } +}; + template struct EqualFunctor { using ELEM_TYPE = T; @@ -76,7 +88,7 @@ class CompareOpKernel } // namespace operators } // namespace paddle -#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \ +#define REGISTER_COMPARE_KERNEL(op_type, dev, functor) \ REGISTER_OP_##dev##_KERNEL( \ op_type, ::paddle::operators::CompareOpKernel< \ ::paddle::platform::dev##DeviceContext, functor>, \ diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index beebc1a85f88511822e7f8ad4cd62fc024318430..faccc3ddf827e4211c9f2e61da7138e5d43f1d11 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -157,7 +157,9 @@ def monkey_patch_variable(): ("__eq__", "equal", False), ("__ne__", "not_equal", False), ("__lt__", "less_than", False), - ("__le__", "less_equal", False)): + ("__le__", "less_equal", False), + ("__gt__", "greater_than", False), + ("__ge__", "greater_equal", False)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/tests/unittests/test_compare_op.py b/python/paddle/v2/fluid/tests/unittests/test_compare_op.py index 83d57639ca4b2dab8ce62a23551161dc2766783a..405afebae85eaae6f6af0012058ad58c8bb69a2f 100644 --- a/python/paddle/v2/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/unittests/test_compare_op.py @@ -38,7 +38,10 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) + create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) if __name__ == '__main__': unittest.main()