From ee4e5323d28eaac69cb14a228b4aff3e8b22db57 Mon Sep 17 00:00:00 2001 From: sprouteer <89541335+sprouteer@users.noreply.github.com> Date: Fri, 20 Jan 2023 11:52:44 +0800 Subject: [PATCH] add unique support zero dim (#49260) --- paddle/phi/infermeta/unary.cc | 21 +++++++--- .../tests/unittests/test_zero_dim_tensor.py | 40 +++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 5a7b2cf16a..55e895c662 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4596,17 +4596,26 @@ void UniqueRawInferMeta(const MetaTensor& x, MetaTensor* index, MetaTensor* counts) { if (!is_sorted) { - PADDLE_ENFORCE_EQ( - x.dims().size(), - 1, - phi::errors::InvalidArgument("The Input(X) should be 1-D Tensor, " - "But now the dims of Input(X) is %d.", - x.dims().size())); + PADDLE_ENFORCE_EQ(x.dims().size() == 1 || x.dims().size() == 0, + true, + phi::errors::InvalidArgument( + "The Input(X) should be 0-D or 1-D Tensor, " + "But now the dims of Input(X) is %d.", + x.dims().size())); out->set_dims(phi::make_ddim({-1})); index->set_dims(x.dims()); return; } + if (x.dims().size() == 0) { + PADDLE_ENFORCE_EQ(axis.empty(), + true, + phi::errors::InvalidArgument( + "The Input(X) with 0-D Tensor, axis must be None" + "But now the axis is %d.", + axis[0])); + } + if (axis.empty()) { out->set_dims(phi::make_ddim({-1})); if (return_inverse) { diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 7a290beab7..11d85b5244 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -2445,6 +2445,29 @@ class TestNoBackwardAPI(unittest.TestCase): self.assertEqual(one_hot_label.shape, [4]) self.assertEqual(one_hot_label.numpy()[2], 1) + def test_unique(self): + places = ['cpu'] + if paddle.is_compiled_with_cuda(): + places.append('gpu') + for place in places: + paddle.set_device(place) + x = paddle.rand([]) + y, index, inverse, counts = paddle.unique( + x, + return_index=True, + return_inverse=True, + return_counts=True, + ) + + self.assertEqual(y, x) + self.assertEqual(index, 0) + self.assertEqual(inverse, 0) + self.assertEqual(counts, 1) + self.assertEqual(y.shape, [1]) + self.assertEqual(index.shape, [1]) + self.assertEqual(inverse.shape, [1]) + self.assertEqual(counts.shape, [1]) + class TestNoBackwardAPIStatic(unittest.TestCase): def setUp(self): @@ -2647,6 +2670,23 @@ class TestNoBackwardAPIStatic(unittest.TestCase): self.assertEqual(res[0].shape, (4,)) self.assertEqual(res[0][2], 1) + def test_unique(self): + x = paddle.rand([]) + y, index, inverse, counts = paddle.unique( + x, return_index=True, return_inverse=True, return_counts=True + ) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[y, index, inverse, counts]) + self.assertEqual(y, x) + self.assertEqual(index, 0) + self.assertEqual(inverse, 0) + self.assertEqual(counts, 1) + self.assertEqual(res[0].shape, (1,)) + self.assertEqual(res[1].shape, (1,)) + self.assertEqual(res[2].shape, (1,)) + self.assertEqual(res[3].shape, (1,)) + unary_apis_with_complex_input = [ paddle.real, -- GitLab