diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 39ea06c89e6ba4ceb9a9552ed910ad869cff6fa5..e1aebcfbcece05343b97bc7651d1dc76a4d927d8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -161,6 +161,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, const auto& x_dims = x.dims(); auto x_rank = x.dims().size(); + auto zero_dim_tensor = x_rank == 0; if (x_rank > 0) { PADDLE_ENFORCE_GE(int_axis, -x_rank, @@ -178,11 +179,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x, x_rank)); } else { // 0-dim tensor - PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten, + PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1, true, phi::errors::InvalidArgument( "'axis'(%d) must be 0 or -1 if input tensor is " - "0-dim. and flatten should be true.", + "0-dim.", int_axis)); } @@ -191,7 +192,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, if (config.is_runtime) { if (dtype == phi::TransToProtoVarType(DataType::INT32)) { int64_t all_element_num = 0; - if (flatten) { + if (flatten || zero_dim_tensor) { all_element_num = phi::product(x_dims); } else { all_element_num = x_dims[int_axis]; diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 190f0eaac8911ca035c3388fa6cfcde759d042ed..94740925ea21d8d67933d2730799ae0f8f8ef3fa 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -571,7 +571,7 @@ class TestSundryAPI(unittest.TestCase): paddle.disable_static() self.x = paddle.rand([]) - def _test_argmin(self): + def test_argmin(self): x = paddle.rand([]) out1 = paddle.argmin(x, 0) out2 = paddle.argmin(x, -1) @@ -585,7 +585,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out3.shape, []) np.testing.assert_allclose(out3, 0.0) - def _test_argmax(self): + def test_argmax(self): x = paddle.rand([]) out1 = paddle.argmax(x, 0) out2 = paddle.argmax(x, -1) @@ -1641,7 +1641,7 @@ class TestSundryAPIStatic(unittest.TestCase): self.exe = paddle.static.Executor() @prog_scope() - def _test_argmin(self): + def test_argmin(self): x = paddle.rand([]) out1 = paddle.argmin(x, 0) out2 = paddle.argmin(x, -1) @@ -1664,7 +1664,7 @@ class TestSundryAPIStatic(unittest.TestCase): np.testing.assert_allclose(res[2], 0.0) @prog_scope() - def _test_argmax(self): + def test_argmax(self): x = paddle.rand([]) out1 = paddle.argmax(x, 0) out2 = paddle.argmax(x, -1) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 573dbb15476b6ba4493726a81743c6b655cc7bce..518b1c1488e17a8adf6b6d202f13ef0589da5bf4 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -328,7 +328,7 @@ class TestSundryAPI(unittest.TestCase): paddle.disable_static() self.x = paddle.rand([]) - def _test_argmin(self): + def test_argmin(self): x = paddle.rand([]) out1 = paddle.argmin(x, 0) out2 = paddle.argmin(x, -1) @@ -342,7 +342,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out3.shape, []) np.testing.assert_allclose(out3, 0.0) - def _test_argmax(self): + def test_argmax(self): x = paddle.rand([]) out1 = paddle.argmax(x, 0) out2 = paddle.argmax(x, -1)