From 28d07e3cb88737a0abe95d3ed4b4660e3a832dc2 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 24 Feb 2018 13:15:05 +0800 Subject: [PATCH] add python part of compare op --- paddle/fluid/operators/compare_op.cc | 11 ++++++----- paddle/fluid/operators/compare_op.cu | 8 ++++---- paddle/fluid/operators/compare_op.h | 4 ++-- python/paddle/v2/fluid/layers/math_op_patch.py | 4 +++- .../v2/fluid/tests/unittests/test_compare_op.py | 3 +++ 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index a5f40d54829..86f7046058c 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -100,11 +100,12 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); -REGISTER_COMPARE_OP(larger_than, "Out = X > Y"); -REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor); -REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y"); -REGISTER_COMPARE_KERNEL(larger_equal, CPU, - paddle::operators::LargerEqualFunctor); +REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); +REGISTER_COMPARE_KERNEL(greater_than, CPU, + paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); +REGISTER_COMPARE_KERNEL(greater_equal, CPU, + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index 32d92a90666..1bf85c64fb5 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -16,9 +16,9 @@ limitations under the License. */ REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); -REGISTER_COMPARE_KERNEL(larger_than, CUDA, - paddle::operators::LargerThanFunctor); -REGISTER_COMPARE_KERNEL(larger_equal, CUDA, - paddle::operators::LargerEqualFunctor); +REGISTER_COMPARE_KERNEL(greater_than, CUDA, + paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_KERNEL(greater_equal, CUDA, + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index b4546f27b1e..1cbabdaf676 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -35,13 +35,13 @@ struct LessEqualFunctor { }; template -struct LargerThanFunctor { +struct GreaterThanFunctor { using ELEM_TYPE = T; HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } }; template -struct LargerEqualFunctor { +struct GreaterEqualFunctor { using ELEM_TYPE = T; HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } }; diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 417a01b76f1..c92eb94f09c 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -157,7 +157,9 @@ def monkey_patch_variable(): ("__eq__", "equal", False), ("__ne__", "not_equal", False), ("__lt__", "less_than", False), - ("__le__", "less_equal", False)): + ("__le__", "less_equal", False), + ("__gt__", "greater_than", False), + ("__ge__", "greater_equal", False)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/tests/unittests/test_compare_op.py b/python/paddle/v2/fluid/tests/unittests/test_compare_op.py index 83d57639ca4..405afebae85 100644 --- a/python/paddle/v2/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/unittests/test_compare_op.py @@ -38,7 +38,10 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) + create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) if __name__ == '__main__': unittest.main() -- GitLab