diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index f3414c33b5ab3cc8dffee640fd85b9625b3f237b..b1f09fb0029affe671d63874cf3d3db86476c367 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_OP(not_equal, "Out = X != Y"); +REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index 3507af2ae3add8cf02f5b9f3b3d89b40d73bcb0d..00263a2ade4502e732d53b871665185f8d0fa9f1 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -17,3 +17,4 @@ limitations under the License. */ REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index 4b2ee5a9d68f5f1fd3d2d374669763855659f1db..c651335268fee08c08bcac6247f5a2ff92784330 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -48,6 +48,14 @@ struct EqualFunctor { } }; +template +struct NotEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { + return !EqualFunctor()(a, b); + } +}; + template class CompareOpKernel : public framework::OpKernel { diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 5301c3d1ded091a203b62e9e52dceca022e68683..8208629af7931a32357ad07ac9d7110d0727fc0e 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -154,8 +154,9 @@ def monkey_patch_variable(): ("__rpow__", "elementwise_pow", True), # for logical compare ("__eq__", "equal", False), + ("__ne__", "not_equal", False), ("__lt__", "less_than", False), - ("__le__", "less_equal", False), ): + ("__le__", "less_equal", False)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index aecae3332bbaecce478e89a38228db122e1e7e50..5ef0097388057de79a85e0c0548c18bd01676d2d 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -53,6 +53,7 @@ class TestPythonOperatorOverride(unittest.TestCase): lambda _a, _b: _a > _b, lambda _a, _b: _a <= _b, lambda _a, _b: _a >= _b, + lambda _a, _b: _a != _b, ] # places to check