diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index c7b9aa157ebb9ee903a13aa536363bf252641f9b..4a84e7a2e71c78cce82cf802392f0fab32544082 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so") set(XPU_RT_LIB_NAME "libxpurt.so") set(XPU_XFT_LIB_NAME "libxft.so") -set(XPU_BASE_DATE "20230227") +set(XPU_BASE_DATE "20230308") set(XPU_XCCL_BASE_VERSION "1.0.10") set(XPU_XFT_BASE_VERSION "latest") diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index f48a91c82858a314f5a21e5421c9bd449b616b56..d0814e794446e22c05bbc419720c0bdf97293b76 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -391,6 +391,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::FLOAT16, phi::DataType::FLOAT32, phi::DataType::BOOL})}, + {"index_select_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"index_select", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, diff --git a/paddle/phi/kernels/xpu/index_select_grad_kernel.cc b/paddle/phi/kernels/xpu/index_select_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..14bfce38799f0c1c410b089b15f4f5387e238aac --- /dev/null +++ b/paddle/phi/kernels/xpu/index_select_grad_kernel.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_select_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { + +template +void IndexSelectGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& index, + const DenseTensor& out_grad, + int dim, + DenseTensor* x_grad) { + if (dim < 0) { + dim += out_grad.dims().size(); + } + const auto& index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + + T* x_grad_data = ctx.template Alloc(x_grad); + const T* out_grad_data = out_grad.data(); + + auto out_grad_shape = phi::vectorize(out_grad.dims()); + auto x_grad_shape = phi::vectorize(x_grad->dims()); + + int r = xpu::Error_t::SUCCESS; + if (index_type == phi::DataType::INT32) { + const int* index_data = index.data(); + r = xpu::index_select_grad(ctx.x_context(), + nullptr, + index_data, + out_grad_data, + dim, + x_grad_data, + out_grad_shape, + x_grad_shape); + } else if (index_type == phi::DataType::INT64) { + const int64_t* index_data = index.data(); + r = xpu::index_select_grad(ctx.x_context(), + nullptr, + index_data, + out_grad_data, + dim, + x_grad_data, + out_grad_shape, + x_grad_shape); + } + PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_select_grad"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + index_select_grad, XPU, ALL_LAYOUT, phi::IndexSelectGradKernel, float) {}