diff --git a/paddle/phi/kernels/xpu/compare_kernel.cc b/paddle/phi/kernels/xpu/compare_kernel.cc index 32866e7aa701e7b124a1f22043078fe8b9811233..bda9e81c2a1565d7d7c17da996faab8a195b9fd9 100644 --- a/paddle/phi/kernels/xpu/compare_kernel.cc +++ b/paddle/phi/kernels/xpu/compare_kernel.cc @@ -36,6 +36,13 @@ void XPUCompareKernelImpl(const Context& dev_ctx, auto x_shape = vectorize(x.dims()); auto y_shape = vectorize(y.dims()); + if (x.dims().size() == 0) { + x_shape = std::vector({1}); + } + if (y.dims().size() == 0) { + y_shape = std::vector({1}); + } + auto x_data = reinterpret_cast(x.data()); auto y_data = reinterpret_cast(y.data()); auto* out_data = dev_ctx.template Alloc(out); diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index c5b69f8c59af6aebb20473c13cb50651e6b815fc..2a598cae044169e2b86e5a0ea6e4609b572bd98e 100755 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -283,54 +283,6 @@ def create_paddle_case(op_type, callback): self.assertEqual((out.numpy() == self.real_result).all(), True) paddle.enable_static() - def test_zero_dim_api_1(self): - paddle.enable_static() - with program_guard(Program(), Program()): - x = paddle.randint(-3, 3, shape=[], dtype='int32') - y = paddle.randint(-3, 3, shape=[], dtype='int32') - op = eval("paddle.%s" % (self.op_type)) - out = op(x, y) - exe = paddle.static.Executor(self.place) - ( - x_np, - y_np, - res, - ) = exe.run(fetch_list=[x, y, out]) - real_result = callback(x_np, y_np) - self.assertEqual((res == real_result).all(), True) - - def test_zero_dim_api_2(self): - paddle.enable_static() - with program_guard(Program(), Program()): - x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') - y = paddle.randint(-3, 3, shape=[], dtype='int32') - op = eval("paddle.%s" % (self.op_type)) - out = op(x, y) - exe = paddle.static.Executor(self.place) - ( - x_np, - y_np, - res, - ) = exe.run(fetch_list=[x, y, out]) - real_result = callback(x_np, y_np) - self.assertEqual((res == real_result).all(), True) - - def test_zero_dim_api_3(self): - paddle.enable_static() - with program_guard(Program(), Program()): - x = paddle.randint(-3, 3, shape=[], dtype='int32') - y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') - op = eval("paddle.%s" % (self.op_type)) - out = op(x, y) - exe = paddle.static.Executor(self.place) - ( - x_np, - y_np, - res, - ) = exe.run(fetch_list=[x, y, out]) - real_result = callback(x_np, y_np) - self.assertEqual((res == real_result).all(), True) - def test_broadcast_api_1(self): paddle.enable_static() with program_guard(Program(), Program()): @@ -383,6 +335,54 @@ def create_paddle_case(op_type, callback): ) self.assertEqual((res == real_result).all(), True) + def test_zero_dim_api_1(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.randint(-3, 3, shape=[], dtype='int32') + y = paddle.randint(-3, 3, shape=[], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + ( + x_np, + y_np, + res, + ) = exe.run(fetch_list=[x, y, out]) + real_result = callback(x_np, y_np) + self.assertEqual((res == real_result).all(), True) + + def test_zero_dim_api_2(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') + y = paddle.randint(-3, 3, shape=[], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + ( + x_np, + y_np, + res, + ) = exe.run(fetch_list=[x, y, out]) + real_result = callback(x_np, y_np) + self.assertEqual((res == real_result).all(), True) + + def test_zero_dim_api_3(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.randint(-3, 3, shape=[], dtype='int32') + y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + ( + x_np, + y_np, + res, + ) = exe.run(fetch_list=[x, y, out]) + real_result = callback(x_np, y_np) + self.assertEqual((res == real_result).all(), True) + def test_bool_api_4(self): paddle.enable_static() with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py index 8e7fb2eb3421de17be33c30f3285197ec78cff54..7fbe1f6ccf7375ec8bb58d92d7185976a076147a 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py @@ -100,6 +100,27 @@ class XPUTestLessThanOP(XPUOpTestWrapper): self.x_shape = [128, 128, 512] self.y_shape = [128, 128, 512] + class LessThanOpTestCase_ZeroDim1(LessThanOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [] + + class LessThanOpTestCase_ZeroDim2(LessThanOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [11, 17] + + class LessThanOpTestCase_ZeroDim3(LessThanOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [11, 17] + self.y_shape = [] + support_types = get_xpu_op_support_types('less_than') for stype in support_types: @@ -152,6 +173,27 @@ class XPUTestLessEqualOp(XPUOpTestWrapper): self.x_shape = [128, 128, 512] self.y_shape = [128, 128, 512] + class LessEqualOpTestCase_ZeroDim1(LessEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [] + + class LessEqualOpTestCase_ZeroDim2(LessEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [11, 17] + + class LessEqualOpTestCase_ZeroDim3(LessEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [11, 17] + self.y_shape = [] + support_types = get_xpu_op_support_types('less_equal') for stype in support_types: @@ -204,6 +246,27 @@ class XPUTestGreaterThanOp(XPUOpTestWrapper): self.x_shape = [10, 10, 20, 20] self.y_shape = [10, 10, 20, 20] + class GreaterThanOpTestCase_ZeroDim1(GreaterThanOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [] + + class GreaterThanOpTestCase_ZeroDim2(GreaterThanOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [11, 17] + + class GreaterThanOpTestCase_ZeroDim3(GreaterThanOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [11, 17] + self.y_shape = [] + support_types = get_xpu_op_support_types('greater_than') for stype in support_types: @@ -256,6 +319,27 @@ class XPUTestGreaterEqualOp(XPUOpTestWrapper): self.x_shape = [10, 30, 15] self.y_shape = [10, 30, 15] + class GreaterEqualOpTestCase_ZeroDim1(GreaterEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [] + + class GreaterEqualOpTestCase_ZeroDim2(GreaterEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [11, 17] + + class GreaterEqualOpTestCase_ZeroDim3(GreaterEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [11, 17] + self.y_shape = [] + support_types = get_xpu_op_support_types('greater_equal') for stype in support_types: @@ -308,6 +392,27 @@ class XPUTestEqualOp(XPUOpTestWrapper): self.x_shape = [11, 17] self.y_shape = [1] + class EqualOpTestCase_ZeroDim1(EqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [] + + class EqualOpTestCase_ZeroDim2(EqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [11, 17] + + class EqualOpTestCase_ZeroDim3(EqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [11, 17] + self.y_shape = [] + support_types = get_xpu_op_support_types('equal') for stype in support_types: @@ -360,6 +465,27 @@ class XPUTestNotEqualOp(XPUOpTestWrapper): self.x_shape = [512, 128] self.y_shape = [512, 128] + class NotEqualOpTestCase_ZeroDim1(NotEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [] + + class NotEqualOpTestCase_ZeroDim2(NotEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [] + self.y_shape = [11, 17] + + class NotEqualOpTestCase_ZeroDim3(NotEqualOpTestCase1): + def set_data(self): + self.lbound = -100 + self.hbound = 100 + self.x_shape = [11, 17] + self.y_shape = [] + support_types = get_xpu_op_support_types('not_equal') for stype in support_types: