未验证 提交 aec1e4ce 编写于 作者: Z Zhong Hui 提交者: GitHub

[Zero-Dim] Fix 0d axis support for argmin/argmax (#50293)

上级 b3888614
...@@ -161,6 +161,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, ...@@ -161,6 +161,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
auto x_rank = x.dims().size(); auto x_rank = x.dims().size();
auto zero_dim_tensor = x_rank == 0;
if (x_rank > 0) { if (x_rank > 0) {
PADDLE_ENFORCE_GE(int_axis, PADDLE_ENFORCE_GE(int_axis,
-x_rank, -x_rank,
...@@ -178,11 +179,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x, ...@@ -178,11 +179,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
x_rank)); x_rank));
} else { } else {
// 0-dim tensor // 0-dim tensor
PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten, PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1,
true, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"'axis'(%d) must be 0 or -1 if input tensor is " "'axis'(%d) must be 0 or -1 if input tensor is "
"0-dim. and flatten should be true.", "0-dim.",
int_axis)); int_axis));
} }
...@@ -191,7 +192,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, ...@@ -191,7 +192,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if (config.is_runtime) { if (config.is_runtime) {
if (dtype == phi::TransToProtoVarType(DataType::INT32)) { if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
int64_t all_element_num = 0; int64_t all_element_num = 0;
if (flatten) { if (flatten || zero_dim_tensor) {
all_element_num = phi::product(x_dims); all_element_num = phi::product(x_dims);
} else { } else {
all_element_num = x_dims[int_axis]; all_element_num = x_dims[int_axis];
......
...@@ -571,7 +571,7 @@ class TestSundryAPI(unittest.TestCase): ...@@ -571,7 +571,7 @@ class TestSundryAPI(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
self.x = paddle.rand([]) self.x = paddle.rand([])
def _test_argmin(self): def test_argmin(self):
x = paddle.rand([]) x = paddle.rand([])
out1 = paddle.argmin(x, 0) out1 = paddle.argmin(x, 0)
out2 = paddle.argmin(x, -1) out2 = paddle.argmin(x, -1)
...@@ -585,7 +585,7 @@ class TestSundryAPI(unittest.TestCase): ...@@ -585,7 +585,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out3.shape, []) self.assertEqual(out3.shape, [])
np.testing.assert_allclose(out3, 0.0) np.testing.assert_allclose(out3, 0.0)
def _test_argmax(self): def test_argmax(self):
x = paddle.rand([]) x = paddle.rand([])
out1 = paddle.argmax(x, 0) out1 = paddle.argmax(x, 0)
out2 = paddle.argmax(x, -1) out2 = paddle.argmax(x, -1)
...@@ -1641,7 +1641,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1641,7 +1641,7 @@ class TestSundryAPIStatic(unittest.TestCase):
self.exe = paddle.static.Executor() self.exe = paddle.static.Executor()
@prog_scope() @prog_scope()
def _test_argmin(self): def test_argmin(self):
x = paddle.rand([]) x = paddle.rand([])
out1 = paddle.argmin(x, 0) out1 = paddle.argmin(x, 0)
out2 = paddle.argmin(x, -1) out2 = paddle.argmin(x, -1)
...@@ -1664,7 +1664,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1664,7 +1664,7 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_allclose(res[2], 0.0) np.testing.assert_allclose(res[2], 0.0)
@prog_scope() @prog_scope()
def _test_argmax(self): def test_argmax(self):
x = paddle.rand([]) x = paddle.rand([])
out1 = paddle.argmax(x, 0) out1 = paddle.argmax(x, 0)
out2 = paddle.argmax(x, -1) out2 = paddle.argmax(x, -1)
......
...@@ -328,7 +328,7 @@ class TestSundryAPI(unittest.TestCase): ...@@ -328,7 +328,7 @@ class TestSundryAPI(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
self.x = paddle.rand([]) self.x = paddle.rand([])
def _test_argmin(self): def test_argmin(self):
x = paddle.rand([]) x = paddle.rand([])
out1 = paddle.argmin(x, 0) out1 = paddle.argmin(x, 0)
out2 = paddle.argmin(x, -1) out2 = paddle.argmin(x, -1)
...@@ -342,7 +342,7 @@ class TestSundryAPI(unittest.TestCase): ...@@ -342,7 +342,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out3.shape, []) self.assertEqual(out3.shape, [])
np.testing.assert_allclose(out3, 0.0) np.testing.assert_allclose(out3, 0.0)
def _test_argmax(self): def test_argmax(self):
x = paddle.rand([]) x = paddle.rand([])
out1 = paddle.argmax(x, 0) out1 = paddle.argmax(x, 0)
out2 = paddle.argmax(x, -1) out2 = paddle.argmax(x, -1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册