diff --git a/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc b/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc index 52a98c63f48987a348e0df4e15d4469f7402f7c3..8e2f56adfa14147c240949c9d1f483098037cc6b 100644 --- a/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc @@ -33,6 +33,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx, auto mask_shape = phi::vectorize(mask.dims()); auto xshape = phi::vectorize(x_grad->dims()); + if (mask.dims().size() == 0) { + mask_shape = std::vector({1}); + } + if (x_grad->dims().size() == 0) { + xshape = std::vector({1}); + } int r = xpu::masked_select_grad(dev_ctx.x_context(), input_data, diff --git a/paddle/phi/kernels/xpu/masked_select_kernel.cc b/paddle/phi/kernels/xpu/masked_select_kernel.cc index 0f142e852a9a7fd9f77ff10fc439938b89c7b3c4..c572b5c6e4eb741a5b08edac13c227cfaa479663 100644 --- a/paddle/phi/kernels/xpu/masked_select_kernel.cc +++ b/paddle/phi/kernels/xpu/masked_select_kernel.cc @@ -61,6 +61,12 @@ void MaskedSelectKernel(const Context& dev_ctx, auto input_shape = vectorize(input_dim); auto mask_shape = vectorize(mask_dim); + if (input_dim.size() == 0) { + input_shape = std::vector({1}); + } + if (mask_dim.size() == 0) { + mask_shape = std::vector({1}); + } if (out_size_cpu > 0) { PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),