diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 002af75c04c1facb11f83d9c2f29374af7d97c41..561938adca80a22cc3700baab3dc58c8bf9a6321 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -381,7 +381,11 @@ void CompareAllInferMeta(const MetaTensor& x, errors::InvalidArgument( "The size of dim_y should not be greater than dim_x's.")); out->share_lod(x); - out->set_dims(make_ddim({1})); + if (!x.dims().size() || !y.dims().size()) { + out->set_dims(make_ddim({})); + } else { + out->set_dims(make_ddim({1})); + } out->set_dtype(DataType::BOOL); } diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py old mode 100644 new mode 100755 index eedf4ae596cab6cc26aa279427cf1ce188bcdb05..8c420bf65cce7482fa7f129f469fde6ba4eabf10 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -1255,6 +1255,13 @@ class TestSundryAPI(unittest.TestCase): y = paddle.full([], 0.6) self.assertFalse(paddle.allclose(x, y)) + def test_equalall(self): + x = paddle.full([], 0.5) + y = paddle.full([], 0.6) + out = paddle.equal_all(x, y) + self.assertEqual(out.shape, []) + self.assertFalse(out) + def test_where(self): x1 = paddle.full([], 1) x2 = paddle.full([], 2) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py old mode 100644 new mode 100755 index dfa04456bf8b993043245f4c9242ea475c93aa74..2cedf9e60bd0c030b2910b7cff2dcd8514a63326 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -780,6 +780,13 @@ class TestSundryAPI(unittest.TestCase): y = paddle.full([], 0.6) self.assertFalse(paddle.allclose(x, y)) + def test_equalall(self): + x = paddle.full([], 0.5) + y = paddle.full([], 0.6) + out = paddle.equal_all(x, y) + self.assertEqual(out.shape, []) + self.assertFalse(out) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.