diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index bf047de86fc21a4d5d9e9ff8f20c9a1982eb25af..a03e4165755dde3211425b028b474896249237f7 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -131,18 +131,18 @@ class CompareOp : public framework::OperatorWithKernel { REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor, - paddle::operators::GreaterEqualFunctor); + paddle::operators::GreaterThanFunctor); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor, - paddle::operators::GreaterThanFunctor); + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); REGISTER_COMPARE_KERNEL(greater_than, CPU, paddle::operators::GreaterThanFunctor, - paddle::operators::LessEqualFunctor); + paddle::operators::LessThanFunctor); REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); REGISTER_COMPARE_KERNEL(greater_equal, CPU, paddle::operators::GreaterEqualFunctor, - paddle::operators::LessThanFunctor); + paddle::operators::LessEqualFunctor); REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor, paddle::operators::EqualFunctor); diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index 3ca700e16e6e7bcf4136ca68dd895593a63824ec..a60201f9d07d69897ec81ced54964a50a9d84795 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -15,15 +15,15 @@ limitations under the License. */ #include "paddle/fluid/operators/controlflow/compare_op.h" REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor, - paddle::operators::GreaterEqualFunctor); -REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor, paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor, + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_KERNEL(greater_than, CUDA, paddle::operators::GreaterThanFunctor, - paddle::operators::LessEqualFunctor); + paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(greater_equal, CUDA, paddle::operators::GreaterEqualFunctor, - paddle::operators::LessThanFunctor); + paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor, diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index 8dc80c893126925ff0643b4cde622fe504c6b1d9..a2dd7e49ac4ccdd6135a27d7b88f6fbdec2132b9 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -139,6 +139,22 @@ def create_paddle_case(op_type, callback): fetch_list=[out]) self.assertEqual((res == real_result).all(), True) + def test_broadcast_api_3(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[5], dtype='int32') + y = paddle.static.data(name='y', shape=[3, 1], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(0, 5).reshape((5)).astype(np.int32) + input_y = np.array([5, 3, 2]).reshape((3, 1)).astype(np.int32) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + def test_attr_name(self): paddle.enable_static() with program_guard(Program(), Program()):