From 6f78cb996912d056c7df131838d2c0a79a018e19 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 11 Feb 2018 10:34:24 +0800 Subject: [PATCH] add not_equal --- paddle/fluid/operators/compare_op.cc | 2 ++ paddle/fluid/operators/compare_op.cu | 1 + paddle/fluid/operators/compare_op.h | 8 ++++++++ python/paddle/v2/fluid/layers/math_op_patch.py | 3 ++- .../v2/fluid/tests/test_python_operator_overriding.py | 1 + 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index f3414c33b5..b1f09fb002 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_OP(not_equal, "Out = X != Y"); +REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index 3507af2ae3..00263a2ade 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -17,3 +17,4 @@ limitations under the License. */ REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index 4b2ee5a9d6..c651335268 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -48,6 +48,14 @@ struct EqualFunctor { } }; +template +struct NotEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { + return !EqualFunctor()(a, b); + } +}; + template class CompareOpKernel : public framework::OpKernel { diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 5301c3d1de..8208629af7 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -154,8 +154,9 @@ def monkey_patch_variable(): ("__rpow__", "elementwise_pow", True), # for logical compare ("__eq__", "equal", False), + ("__ne__", "not_equal", False), ("__lt__", "less_than", False), - ("__le__", "less_equal", False), ): + ("__le__", "less_equal", False)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index aecae3332b..5ef0097388 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -53,6 +53,7 @@ class TestPythonOperatorOverride(unittest.TestCase): lambda _a, _b: _a > _b, lambda _a, _b: _a <= _b, lambda _a, _b: _a >= _b, + lambda _a, _b: _a != _b, ] # places to check -- GitLab