From 81056073b80abc2fcd86da32fdeb8ee8ae082313 Mon Sep 17 00:00:00 2001 From: RuohengMa <120699764+RuohengMa@users.noreply.github.com> Date: Mon, 15 May 2023 10:59:49 +0800 Subject: [PATCH] [XPU][PHI] bind index_sample_grad xpu kernel (#53753) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 1 + .../kernels/xpu/index_sample_grad_kernel.cc | 82 +++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 paddle/phi/kernels/xpu/index_sample_grad_kernel.cc diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index ba402aed5a6..ba9742a98d0 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -420,6 +420,7 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, phi::DataType::INT64})}, + {"index_sample_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"index_sample", XPUKernelSet({phi::DataType::INT8, phi::DataType::INT16, diff --git a/paddle/phi/kernels/xpu/index_sample_grad_kernel.cc b/paddle/phi/kernels/xpu/index_sample_grad_kernel.cc new file mode 100644 index 00000000000..22c35ef4684 --- /dev/null +++ b/paddle/phi/kernels/xpu/index_sample_grad_kernel.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_sample_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { + +template +void IndexSampleGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& index, + const DenseTensor& out_grad, + DenseTensor* in_grad) { + using XPUType = typename XPUTypeTrait::Type; + const auto& index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + + XPUType* in_grad_data = ctx.template Alloc(in_grad); + const XPUType* out_grad_data = out_grad.data(); + auto in_grad_shape = phi::vectorize(in_grad->dims()); + auto out_grad_shape = phi::vectorize(out_grad.dims()); + auto index_shape = phi::vectorize(index.dims()); + + int r = xpu::constant( + ctx.x_context(), in_grad_data, in_grad->numel(), static_cast(0)); + + if (index_type == phi::DataType::INT32) { + const int* index_data = index.data(); + r = xpu::scatter_element(ctx.x_context(), + in_grad_data, + out_grad_data, + index_data, + in_grad_data, + in_grad_shape, + out_grad_shape, + index_shape, + 1, + 1); + } else if (index_type == phi::DataType::INT64) { + const int64_t* index_data = index.data(); + r = xpu::scatter_element(ctx.x_context(), + in_grad_data, + out_grad_data, + index_data, + in_grad_data, + in_grad_shape, + out_grad_shape, + index_shape, + 1, + 1); + } + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter_element"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + index_sample_grad, XPU, ALL_LAYOUT, phi::IndexSampleGradKernel, float) {} -- GitLab