未验证 提交 aec1e4ce 编写于 作者: Z Zhong Hui 提交者: GitHub

[Zero-Dim] Fix 0d axis support for argmin/argmax (#50293)

上级 b3888614
......@@ -161,6 +161,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
const auto& x_dims = x.dims();
auto x_rank = x.dims().size();
auto zero_dim_tensor = x_rank == 0;
if (x_rank > 0) {
PADDLE_ENFORCE_GE(int_axis,
-x_rank,
......@@ -178,11 +179,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
x_rank));
} else {
// 0-dim tensor
PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten,
PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1,
true,
phi::errors::InvalidArgument(
"'axis'(%d) must be 0 or -1 if input tensor is "
"0-dim. and flatten should be true.",
"0-dim.",
int_axis));
}
......@@ -191,7 +192,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if (config.is_runtime) {
if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
int64_t all_element_num = 0;
if (flatten) {
if (flatten || zero_dim_tensor) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[int_axis];
......
......@@ -571,7 +571,7 @@ class TestSundryAPI(unittest.TestCase):
paddle.disable_static()
self.x = paddle.rand([])
def _test_argmin(self):
def test_argmin(self):
x = paddle.rand([])
out1 = paddle.argmin(x, 0)
out2 = paddle.argmin(x, -1)
......@@ -585,7 +585,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out3.shape, [])
np.testing.assert_allclose(out3, 0.0)
def _test_argmax(self):
def test_argmax(self):
x = paddle.rand([])
out1 = paddle.argmax(x, 0)
out2 = paddle.argmax(x, -1)
......@@ -1641,7 +1641,7 @@ class TestSundryAPIStatic(unittest.TestCase):
self.exe = paddle.static.Executor()
@prog_scope()
def _test_argmin(self):
def test_argmin(self):
x = paddle.rand([])
out1 = paddle.argmin(x, 0)
out2 = paddle.argmin(x, -1)
......@@ -1664,7 +1664,7 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_allclose(res[2], 0.0)
@prog_scope()
def _test_argmax(self):
def test_argmax(self):
x = paddle.rand([])
out1 = paddle.argmax(x, 0)
out2 = paddle.argmax(x, -1)
......
......@@ -328,7 +328,7 @@ class TestSundryAPI(unittest.TestCase):
paddle.disable_static()
self.x = paddle.rand([])
def _test_argmin(self):
def test_argmin(self):
x = paddle.rand([])
out1 = paddle.argmin(x, 0)
out2 = paddle.argmin(x, -1)
......@@ -342,7 +342,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out3.shape, [])
np.testing.assert_allclose(out3, 0.0)
def _test_argmax(self):
def test_argmax(self):
x = paddle.rand([])
out1 = paddle.argmax(x, 0)
out2 = paddle.argmax(x, -1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册