From aec1e4ce906baee35e8444c6aedd308ff887ad81 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Wed, 8 Feb 2023 12:51:44 +0800 Subject: [PATCH] [Zero-Dim] Fix 0d axis support for argmin/argmax (#50293) --- paddle/phi/infermeta/unary.cc | 7 ++++--- .../paddle/fluid/tests/unittests/test_zero_dim_tensor.py | 8 ++++---- .../fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 39ea06c89e..e1aebcfbce 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -161,6 +161,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, const auto& x_dims = x.dims(); auto x_rank = x.dims().size(); + auto zero_dim_tensor = x_rank == 0; if (x_rank > 0) { PADDLE_ENFORCE_GE(int_axis, -x_rank, @@ -178,11 +179,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x, x_rank)); } else { // 0-dim tensor - PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten, + PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1, true, phi::errors::InvalidArgument( "'axis'(%d) must be 0 or -1 if input tensor is " - "0-dim. and flatten should be true.", + "0-dim.", int_axis)); } @@ -191,7 +192,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, if (config.is_runtime) { if (dtype == phi::TransToProtoVarType(DataType::INT32)) { int64_t all_element_num = 0; - if (flatten) { + if (flatten || zero_dim_tensor) { all_element_num = phi::product(x_dims); } else { all_element_num = x_dims[int_axis]; diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 190f0eaac8..94740925ea 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -571,7 +571,7 @@ class TestSundryAPI(unittest.TestCase): paddle.disable_static() self.x = paddle.rand([]) - def _test_argmin(self): + def test_argmin(self): x = paddle.rand([]) out1 = paddle.argmin(x, 0) out2 = paddle.argmin(x, -1) @@ -585,7 +585,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out3.shape, []) np.testing.assert_allclose(out3, 0.0) - def _test_argmax(self): + def test_argmax(self): x = paddle.rand([]) out1 = paddle.argmax(x, 0) out2 = paddle.argmax(x, -1) @@ -1641,7 +1641,7 @@ class TestSundryAPIStatic(unittest.TestCase): self.exe = paddle.static.Executor() @prog_scope() - def _test_argmin(self): + def test_argmin(self): x = paddle.rand([]) out1 = paddle.argmin(x, 0) out2 = paddle.argmin(x, -1) @@ -1664,7 +1664,7 @@ class TestSundryAPIStatic(unittest.TestCase): np.testing.assert_allclose(res[2], 0.0) @prog_scope() - def _test_argmax(self): + def test_argmax(self): x = paddle.rand([]) out1 = paddle.argmax(x, 0) out2 = paddle.argmax(x, -1) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 573dbb1547..518b1c1488 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -328,7 +328,7 @@ class TestSundryAPI(unittest.TestCase): paddle.disable_static() self.x = paddle.rand([]) - def _test_argmin(self): + def test_argmin(self): x = paddle.rand([]) out1 = paddle.argmin(x, 0) out2 = paddle.argmin(x, -1) @@ -342,7 +342,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out3.shape, []) np.testing.assert_allclose(out3, 0.0) - def _test_argmax(self): + def test_argmax(self): x = paddle.rand([]) out1 = paddle.argmax(x, 0) out2 = paddle.argmax(x, -1) -- GitLab