From 1a8be15845d2b721456ab9954eb7bfb2b80704e7 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Wed, 18 Jan 2023 11:39:59 +0800 Subject: [PATCH] [Zero-Dim] Fix bug in masked_select for XPU (#49904) --- paddle/phi/kernels/xpu/masked_select_grad_kernel.cc | 6 ++++++ paddle/phi/kernels/xpu/masked_select_kernel.cc | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc b/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc index 52a98c63f48..8e2f56adfa1 100644 --- a/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc @@ -33,6 +33,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx, auto mask_shape = phi::vectorize(mask.dims()); auto xshape = phi::vectorize(x_grad->dims()); + if (mask.dims().size() == 0) { + mask_shape = std::vector({1}); + } + if (x_grad->dims().size() == 0) { + xshape = std::vector({1}); + } int r = xpu::masked_select_grad(dev_ctx.x_context(), input_data, diff --git a/paddle/phi/kernels/xpu/masked_select_kernel.cc b/paddle/phi/kernels/xpu/masked_select_kernel.cc index 0f142e852a9..c572b5c6e4e 100644 --- a/paddle/phi/kernels/xpu/masked_select_kernel.cc +++ b/paddle/phi/kernels/xpu/masked_select_kernel.cc @@ -61,6 +61,12 @@ void MaskedSelectKernel(const Context& dev_ctx, auto input_shape = vectorize(input_dim); auto mask_shape = vectorize(mask_dim); + if (input_dim.size() == 0) { + input_shape = std::vector({1}); + } + if (mask_dim.size() == 0) { + mask_shape = std::vector({1}); + } if (out_size_cpu > 0) { PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(), -- GitLab