未验证 提交 f287b1e9 编写于 作者: Y yeliang2258 提交者: GitHub

[Zero-Dim] support input 0D Tensor for equal_all (#49845)

* add zero dims test

* update code

* fix zero dims

* update code
上级 8e5ed04d
......@@ -381,7 +381,11 @@ void CompareAllInferMeta(const MetaTensor& x,
errors::InvalidArgument(
"The size of dim_y should not be greater than dim_x's."));
out->share_lod(x);
out->set_dims(make_ddim({1}));
if (!x.dims().size() || !y.dims().size()) {
out->set_dims(make_ddim({}));
} else {
out->set_dims(make_ddim({1}));
}
out->set_dtype(DataType::BOOL);
}
......
......@@ -1255,6 +1255,13 @@ class TestSundryAPI(unittest.TestCase):
y = paddle.full([], 0.6)
self.assertFalse(paddle.allclose(x, y))
def test_equalall(self):
x = paddle.full([], 0.5)
y = paddle.full([], 0.6)
out = paddle.equal_all(x, y)
self.assertEqual(out.shape, [])
self.assertFalse(out)
def test_where(self):
x1 = paddle.full([], 1)
x2 = paddle.full([], 2)
......
......@@ -780,6 +780,13 @@ class TestSundryAPI(unittest.TestCase):
y = paddle.full([], 0.6)
self.assertFalse(paddle.allclose(x, y))
def test_equalall(self):
x = paddle.full([], 0.5)
y = paddle.full([], 0.6)
out = paddle.equal_all(x, y)
self.assertEqual(out.shape, [])
self.assertFalse(out)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册