diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index a5f40d54829b3feca93667e53efdfed7619633e0..86f7046058c7001fcaa588727b1cdc0f3f20c35f 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -100,11 +100,12 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); -REGISTER_COMPARE_OP(larger_than, "Out = X > Y"); -REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor); -REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y"); -REGISTER_COMPARE_KERNEL(larger_equal, CPU, - paddle::operators::LargerEqualFunctor); +REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); +REGISTER_COMPARE_KERNEL(greater_than, CPU, + paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); +REGISTER_COMPARE_KERNEL(greater_equal, CPU, + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index 32d92a9066677c1e526b7d8ce604b40103bd1ed4..1bf85c64fb5b4d79c62118959fd72b13ed1c63ed 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -16,9 +16,9 @@ limitations under the License. */ REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); -REGISTER_COMPARE_KERNEL(larger_than, CUDA, - paddle::operators::LargerThanFunctor); -REGISTER_COMPARE_KERNEL(larger_equal, CUDA, - paddle::operators::LargerEqualFunctor); +REGISTER_COMPARE_KERNEL(greater_than, CUDA, + paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_KERNEL(greater_equal, CUDA, + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index b4546f27b1ef831310f3616df6c1e403d493a577..1cbabdaf6767815c1fedba0eabec9b5de678e047 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -35,13 +35,13 @@ struct LessEqualFunctor { }; template -struct LargerThanFunctor { +struct GreaterThanFunctor { using ELEM_TYPE = T; HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } }; template -struct LargerEqualFunctor { +struct GreaterEqualFunctor { using ELEM_TYPE = T; HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } }; diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 417a01b76f16336d38a3f7589f660b1a7779594e..c92eb94f09c0232e2f55e410fc033e10ee2788cc 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -157,7 +157,9 @@ def monkey_patch_variable(): ("__eq__", "equal", False), ("__ne__", "not_equal", False), ("__lt__", "less_than", False), - ("__le__", "less_equal", False)): + ("__le__", "less_equal", False), + ("__gt__", "greater_than", False), + ("__ge__", "greater_equal", False)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/tests/unittests/test_compare_op.py b/python/paddle/v2/fluid/tests/unittests/test_compare_op.py index 83d57639ca4b2dab8ce62a23551161dc2766783a..405afebae85eaae6f6af0012058ad58c8bb69a2f 100644 --- a/python/paddle/v2/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/unittests/test_compare_op.py @@ -38,7 +38,10 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) + create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) if __name__ == '__main__': unittest.main()