From c72ed8244e8cbc61dd8d0a7f26b76295f3e35c98 Mon Sep 17 00:00:00 2001 From: wawltor Date: Tue, 18 May 2021 10:17:32 +0800 Subject: [PATCH] fix the paddle compare op for the broadcast when the element equal (#32941) * fix the paddle compare op for the broadcast * fix compare op in for in the cuda device --- paddle/fluid/operators/controlflow/compare_op.cc | 8 ++++---- paddle/fluid/operators/controlflow/compare_op.cu | 8 ++++---- .../fluid/tests/unittests/test_compare_op.py | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index bf047de86fc..a03e4165755 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -131,18 +131,18 @@ class CompareOp : public framework::OperatorWithKernel { REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor, - paddle::operators::GreaterEqualFunctor); + paddle::operators::GreaterThanFunctor); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor, - paddle::operators::GreaterThanFunctor); + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); REGISTER_COMPARE_KERNEL(greater_than, CPU, paddle::operators::GreaterThanFunctor, - paddle::operators::LessEqualFunctor); + paddle::operators::LessThanFunctor); REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); REGISTER_COMPARE_KERNEL(greater_equal, CPU, paddle::operators::GreaterEqualFunctor, - paddle::operators::LessThanFunctor); + paddle::operators::LessEqualFunctor); REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor, paddle::operators::EqualFunctor); diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index 3ca700e16e6..a60201f9d07 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -15,15 +15,15 @@ limitations under the License. */ #include "paddle/fluid/operators/controlflow/compare_op.h" REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor, - paddle::operators::GreaterEqualFunctor); -REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor, paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor, + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_KERNEL(greater_than, CUDA, paddle::operators::GreaterThanFunctor, - paddle::operators::LessEqualFunctor); + paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(greater_equal, CUDA, paddle::operators::GreaterEqualFunctor, - paddle::operators::LessThanFunctor); + paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor, diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index 8dc80c89312..a2dd7e49ac4 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -139,6 +139,22 @@ def create_paddle_case(op_type, callback): fetch_list=[out]) self.assertEqual((res == real_result).all(), True) + def test_broadcast_api_3(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[5], dtype='int32') + y = paddle.static.data(name='y', shape=[3, 1], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(0, 5).reshape((5)).astype(np.int32) + input_y = np.array([5, 3, 2]).reshape((3, 1)).astype(np.int32) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + def test_attr_name(self): paddle.enable_static() with program_guard(Program(), Program()): -- GitLab