Unverified commit 3355c0c0 authored by houj04, committed by GitHub

[XPU] fix gather_nd op when index's numel is 0. (#54714)

Parent 1375b3f7
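
Background on the degenerate case this commit handles (an editor's illustration, not part of the change): when index has shape [r0, ..., rk, 0], the trailing 0 means each index row supplies zero coordinates and therefore selects all of x, so the output shape is [r0, ..., rk] followed by x's dims, and its numel is remain_numel * x_numel. A minimal standalone C++ sketch of that shape arithmetic, with GatherNdDegenerateOutShape as a hypothetical helper name:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// With index of shape [r0, ..., rk, 0], every index row refers to all
// of x, so: out dims = [r0, ..., rk] + x dims.
std::vector<int64_t> GatherNdDegenerateOutShape(
    const std::vector<int64_t> &x_dims,
    const std::vector<int64_t> &index_dims) {
  assert(!index_dims.empty() && index_dims.back() == 0);
  std::vector<int64_t> out_dims(index_dims.begin(), index_dims.end() - 1);
  out_dims.insert(out_dims.end(), x_dims.begin(), x_dims.end());
  return out_dims;
}

int main() {
  // x: [2, 3], index: [4, 0]  ->  out: [4, 2, 3]
  auto out_dims = GatherNdDegenerateOutShape({2, 3}, {4, 0});
  int64_t out_numel = std::accumulate(out_dims.begin(), out_dims.end(),
                                      int64_t{1}, std::multiplies<int64_t>());
  // The invariant both kernels below enforce:
  // remain_numel (4) * x_numel (6) == out numel.
  assert(out_numel == 4 * 6);
  return 0;
}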
......
@@ -34,14 +34,42 @@ void GatherNdGradKernel(const Context &ctx,
       ctx.x_context(), dx_data, x_grad->numel(), static_cast<T>(0));
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-  if (out_grad.numel() == 0) return;
+  if (out_grad.numel() == 0) {
+    return;
+  }
   if (index.numel() == 0) {
-    r = xpu::copy(ctx.x_context(),
-                  out_grad.data<T>(),
-                  x_grad->data<T>(),
-                  x_grad->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    auto index_dims = index.dims();
+    auto index_dims_size = index_dims.size();
+    // final dim
+    int64_t end_size = index_dims[index_dims_size - 1];
+    PADDLE_ENFORCE_EQ(
+        end_size,
+        0,
+        phi::errors::InvalidArgument("end_size[%d] should be 0", end_size));
+    // remain dim
+    auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
+    int64_t remain_numel = phi::product(remain_ddim);
+    int64_t x_numel = x.numel();
+    int64_t out_grad_numel = out_grad.numel();
+    PADDLE_ENFORCE_EQ(
+        x_numel * remain_numel,
+        out_grad_numel,
+        phi::errors::InvalidArgument(
+            "x_numel[%d] * remain_numel[%d] should match out_grad_numel[%d]",
+            x_numel,
+            remain_numel,
+            out_grad_numel));
+    // int reduce_sum(Context* ctx, const T* x, T* y, const std::vector<int>&
+    // xshape, const std::vector<int>& rdims)
+    int r = xpu::reduce_sum(ctx.x_context(),
+                            out_grad.data<T>(),
+                            x_grad->data<T>(),
+                            {remain_numel, x_numel},
+                            {0});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
     return;
   }
......
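
Why the grad kernel now reduces instead of copying (editor's note): the forward replicates x across remain_numel rows (see the second hunk below), so each element of x receives gradient contributions from remain_numel positions of out_grad; the old xpu::copy kept only x_grad->numel() values and dropped the rest. A plain-CPU sketch of the axis-0 reduction that the xpu::reduce_sum call performs, as hypothetical reference code rather than the XDNN API:

#include <cstdint>

// CPU reference for the backward fix: view out_grad as a
// [remain_numel, x_numel] matrix and sum over axis 0 into x_grad.
template <typename T>
void ReduceSumAxis0(const T *out_grad, T *x_grad,
                    int64_t remain_numel, int64_t x_numel) {
  for (int64_t j = 0; j < x_numel; ++j) x_grad[j] = T(0);
  for (int64_t i = 0; i < remain_numel; ++i) {
    for (int64_t j = 0; j < x_numel; ++j) {
      x_grad[j] += out_grad[i * x_numel + j];
    }
  }
}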
......
@@ -25,18 +25,47 @@ void GatherNdKernel(const Context &ctx,
                     const DenseTensor &index,
                     DenseTensor *out) {
   ctx.template Alloc<T>(out);
-  const auto &index_type = index.dtype();
-  if (x.numel() == 0) return;
+  if (x.numel() == 0) {
+    return;
+  }
   if (index.numel() == 0) {
-    out->Resize(x.dims());
-    ctx.template Alloc<T>(out);
-    int r = xpu::copy(ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    auto index_dims = index.dims();
+    auto index_dims_size = index_dims.size();
+    // final dim
+    int64_t end_size = index_dims[index_dims_size - 1];
+    PADDLE_ENFORCE_EQ(
+        end_size,
+        0,
+        phi::errors::InvalidArgument("end_size[%d] should be 0", end_size));
+    // remain dim
+    auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
+    int64_t remain_numel = phi::product(remain_ddim);
+    int64_t x_numel = x.numel();
+    int64_t y_numel = out->numel();
+    PADDLE_ENFORCE_EQ(
+        x_numel * remain_numel,
+        y_numel,
+        phi::errors::InvalidArgument(
+            "x_numel[%d] * remain_numel[%d] should match y_numel[%d]",
+            x_numel,
+            remain_numel,
+            y_numel));
+    // int broadcast(Context* ctx, const T* x, T* y, const std::vector<int>&
+    // xshape, const std::vector<int>& yshape)
+    int r = xpu::broadcast(ctx.x_context(),
+                           x.data<T>(),
+                           out->data<T>(),
+                           {1, x_numel},
+                           {remain_numel, x_numel});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
     return;
   }
+  const auto &index_type = index.dtype();
   bool index_type_match =
       index_type == DataType::INT32 || index_type == DataType::INT64;
   PADDLE_ENFORCE_EQ(
......
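
The forward counterpart (editor's note): the old branch resized out to x.dims() and copied only x.numel() elements, which under-fills out whenever remain_numel > 1; tiling x across the leading remain dimension produces the expected [r0, ..., rk, x dims] result. A plain-CPU sketch equivalent to the xpu::broadcast call from shape {1, x_numel} to {remain_numel, x_numel}, again hypothetical reference code rather than the XDNN API:

#include <cstdint>

// CPU reference for the forward fix: replicate x (viewed as
// [1, x_numel]) into each row of out (viewed as [remain_numel, x_numel]).
template <typename T>
void BroadcastRows(const T *x, T *out,
                   int64_t remain_numel, int64_t x_numel) {
  for (int64_t i = 0; i < remain_numel; ++i) {
    for (int64_t j = 0; j < x_numel; ++j) {
      out[i * x_numel + j] = x[j];
    }
  }
}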