Unverified commit e4e94a88, authored by Zhong Hui, committed by GitHub

[Zero-Dim] Fix 0-dim tensor for arg_min_max op. (#49570)

* fix 0-d tensor for arg_min_max op.

* fix xpu.

* fix zero dims

* fix

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update test_zero_dim_tensor.py

* Update test_zero_dim_tensor_xpu.py

* Update test_zero_dim_tensor.py

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc
Parent 71f247b1
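In user-facing terms, this change lets `paddle.argmax`/`paddle.argmin` accept 0-dim (scalar) tensors and return a 0-dim index tensor. A minimal sketch of the intended behavior, assuming a Paddle build that includes this commit, in dynamic graph mode:

```python
import paddle

# A 0-dim (scalar) tensor: its shape is [].
x = paddle.to_tensor(3.14)

# With this fix, argmax/argmin on a 0-dim input no longer trip the
# rank checks; the result is a 0-dim int64 tensor holding index 0,
# the only valid position in a scalar.
out = paddle.argmax(x)
print(out.shape)   # []
print(out.item())  # 0
```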
@@ -160,22 +160,34 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
   auto int_axis = axis.to<int64_t>();
   const auto& x_dims = x.dims();
-  PADDLE_ENFORCE_GE(
-      int_axis,
-      -x_dims.size(),
-      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
-                                   " -Rank(X)(%d).",
-                                   int_axis,
-                                   -x_dims.size()));
-  PADDLE_ENFORCE_LT(int_axis,
-                    x_dims.size(),
-                    phi::errors::InvalidArgument(
-                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
-                        int_axis,
-                        x_dims.size()));
+  auto x_rank = x.dims().size();
+  if (x_rank > 0) {
+    PADDLE_ENFORCE_GE(int_axis,
+                      -x_rank,
+                      phi::errors::InvalidArgument(
+                          "'axis'(%d) must be greater than or equal to"
+                          " -Rank(X)(%d).",
+                          int_axis,
+                          -x_rank));
+    PADDLE_ENFORCE_LT(
+        int_axis,
+        x_rank,
+        phi::errors::InvalidArgument(
+            "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+            int_axis,
+            x_rank));
+  } else {
+    // 0-dim tensor
+    PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "'axis'(%d) must be 0 or -1 if input tensor is "
+                          "0-dim. and flatten should be true.",
+                          int_axis));
+  }
 
-  auto x_rank = x_dims.size();
   if (int_axis < 0) int_axis += x_rank;
   if (config.is_runtime) {
     if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
       int64_t all_element_num = 0;
@@ -195,8 +207,12 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
                           INT_MAX));
     }
   }
   std::vector<int64_t> vec;
-  if (flatten) {
+  if (x_rank == 0) {
+    // vec is set to empty
+  } else if (flatten) {
     vec.emplace_back(static_cast<int64_t>(1));
   } else {
     for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
@@ -205,6 +221,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
     }
     for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
   }
+
   out->set_dims(phi::make_ddim(vec));
   if (dtype == 2) {
     out->set_dtype(DataType::INT32);
...
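The InferMeta change above is what gives the output its 0-dim shape: for a rank-0 input, `vec` stays empty, so `out` receives an empty `DDim`. A rough behavioral check of that shape rule, assuming the same Paddle build; the expected shapes follow directly from the branches in the diff:

```python
import paddle

# 0-dim input: `vec` stays empty above, so the inferred shape is [].
scalar = paddle.rand([])
print(paddle.argmax(scalar).shape)  # []

# n-dim input with flatten (axis=None): the emplace_back(1) branch,
# so the index tensor has shape [1] at this point in the codebase.
mat = paddle.rand([2, 3])
print(paddle.argmax(mat).shape)     # [1]
```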
@@ -96,6 +96,12 @@ struct VisitDataArgMinMaxFunctor {
       if (axis < 0) new_axis = axis + x_dims.size();
     }
+    // For 0D Tensor
+    if (x.dims().size() == 0) {
+      phi::funcs::set_constant(dev_ctx, out, 0);
+      return;
+    }
+
 #define CALL_ARG_MINMAX_FUNCTOR(rank)                                         \
   ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
   functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
...
@@ -30,6 +30,7 @@ namespace cub = hipcub;
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 namespace {  // NOLINT
@@ -180,6 +181,12 @@ struct VisitDataCudaArgMinMaxFunctor {
       x_dims = x.dims();
       if (axis < 0) new_axis = axis + x.dims().size();
     }
+    // For 0D Tensor
+    if (x.dims().size() == 0) {
+      dev_ctx.template Alloc<IndType>(out);
+      phi::funcs::set_constant(dev_ctx, out, 0);
+      return;
+    }
 
     int64_t numel = x.numel();
     int64_t groups = numel / x_dims[new_axis];
...
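Both the CPU and GPU kernels handle the 0-dim case the same way: they early-return after writing the constant index 0 into the output, instead of dispatching the rank-templated functor, which has no rank-0 instantiation. This matches NumPy's convention for scalar arrays; a quick sanity check using only NumPy:

```python
import numpy as np

# NumPy treats a 0-d array as a single element at index 0,
# which is exactly the constant the kernels write.
print(np.argmax(np.array(3.14)))  # 0
print(np.argmin(np.array(3.14)))  # 0
```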
@@ -18,6 +18,7 @@
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
@@ -39,7 +40,15 @@ void ArgMaxKernel(const Context& dev_ctx,
                         DataType::INT64,
                         DataType::INT32,
                         dtype));
+  // TODO(ZHUI): fix dtype of out
   dev_ctx.template Alloc<int64_t>(out);
+  if (x.dims().size() == 0) {
+    xpu::constant(dev_ctx.x_context(),
+                  out->data<int64_t>(),
+                  x.numel(),
+                  static_cast<int64_t>(0));
+    return;
+  }
+
   DDim x_dims;
   int axis_val = axis.to<int>();
...
@@ -189,6 +189,8 @@ reduce_api_list = [
     paddle.logsumexp,
     paddle.all,
     paddle.any,
+    paddle.argmax,
+    paddle.argmin,
 ]
@@ -208,12 +210,13 @@ class TestReduceAPI(unittest.TestCase):
             out.retain_grads()
             out.backward()
 
-            out_empty_list = api(x, [])
-            self.assertEqual(out_empty_list, out)
-
             self.assertEqual(x.shape, [])
             self.assertEqual(out.shape, [])
-            np.testing.assert_allclose(out.numpy(), x.numpy())
+            if api not in [paddle.argmax, paddle.argmin]:
+                np.testing.assert_allclose(out.numpy(), x.numpy())
+                out_empty_list = api(x, [])
+                self.assertEqual(out_empty_list, out)
+
             if x.grad is not None:
                 self.assertEqual(x.grad.shape, [])
                 self.assertEqual(out.grad.shape, [])
@@ -250,7 +253,9 @@ class TestReduceAPI(unittest.TestCase):
             res = exe.run(main_prog, fetch_list=fetch_list)
             self.assertEqual(res[0].shape, ())
             self.assertEqual(res[1].shape, ())
-            np.testing.assert_allclose(res[0], res[1])
+            if api not in [paddle.argmax, paddle.argmin]:
+                np.testing.assert_allclose(res[0], res[1])
+
             if len(res) > 2:
                 self.assertEqual(res[2].shape, ())
                 self.assertEqual(res[3].shape, ())
...
@@ -132,6 +132,8 @@ reduce_api_list = [
     paddle.logsumexp,
     paddle.all,
     paddle.any,
+    paddle.argmax,
+    paddle.argmin,
 ]
@@ -153,7 +155,8 @@ class TestReduceAPI(unittest.TestCase):
             self.assertEqual(x.shape, [])
             self.assertEqual(out.shape, [])
-            np.testing.assert_allclose(out.numpy(), x.numpy())
+            if api not in [paddle.argmax, paddle.argmin]:
+                np.testing.assert_allclose(out.numpy(), x.numpy())
             if x.grad is not None:
                 self.assertEqual(x.grad.shape, [])
                 self.assertEqual(out.grad.shape, [])
...
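The test changes register `paddle.argmax`/`paddle.argmin` in `reduce_api_list` but skip the value comparison against `x`, since these APIs return an index rather than the reduced value. A standalone reproduction of what the new cases cover, assuming a Paddle build with this fix, in dynamic graph mode:

```python
import paddle

paddle.disable_static()
x = paddle.rand([])  # 0-dim input, as in the updated reduce tests

for api in [paddle.argmax, paddle.argmin]:
    out = api(x)
    # The index tensor is 0-dim as well, and the only valid index is 0.
    assert out.shape == []
    assert out.item() == 0
```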