// index_select_grad_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/index_select_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"

// Flag declared here and defined elsewhere. In this file it is read by
// IndexSelectGradKernel: when true, the gradient kernel is launched with a
// single block of a single thread, so the atomic accumulations into x_grad
// happen in a fixed, reproducible order.
DECLARE_bool(cudnn_deterministic);

namespace phi {

using phi::PADDLE_CUDA_NUM_THREADS;

// Scatter-adds out_grad back into x_grad for index_select along one axis.
//
// Both tensors are treated as [outer, axis, inner] with:
//   stride - element step of the selected axis in x_grad (presumably the
//            number of `inner` elements; taken from phi::stride on the host)
//   size   - length of the selected axis in out_grad
//   delta  - x_grad.shape[axis] - size
// For a flat out_grad offset `idx`:
//   pre_idx = idx / (stride * size)             (outer position)
//   dim_idx = (idx % (stride * size)) / stride  (position along the axis)
// The matching x_grad offset replaces dim_idx with index[dim_idx] and
// rescales the outer offset from `size` rows to x_grad.shape[axis] rows:
//   idx + (delta * pre_idx + index[dim_idx] - dim_idx) * stride
//
// Several out_grad elements may land on the same x_grad element (repeated
// indices), so values are accumulated with an atomic add. x_grad must be
// zero-initialised before launch. Index values are NOT bounds-checked here.
// NOTE(review): they are assumed to lie in [0, x_grad.shape[axis]); confirm
// that callers validate this.
//
// Launch: 1-D grid of any size. Iteration over [0, N) is delegated to
// CUDA_KERNEL_LOOP_TYPE; the single-thread deterministic launch in
// IndexSelectGradKernel relies on that loop covering all N elements.
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
                                              T* input_grad,
                                              const IndexT* index,
                                              int64_t N,
                                              int64_t stride,
                                              int64_t size,
                                              int64_t delta) {
  // Number of out_grad elements in one outer slice; loop-invariant.
  const int64_t slice_numel = stride * size;
  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
    int64_t pre_idx = idx / slice_numel;
    int64_t dim_idx = idx % slice_numel / stride;
    IndexT src_dim_idx = index[dim_idx];
    int64_t input_idx =
        idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
    phi::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
  }
}

// GPU backward of index_select: x_grad is zero-filled, then every element of
// out_grad is added into x_grad at the position selected by `index` along
// `dim`.
//
// Args:
//   ctx      device context; its stream runs the zero-fill and the kernel.
//   x        forward input (not read here; x_grad's dims are used instead).
//   index    selection indices, dtype int32 or int64; read as index[i] for
//            i in [0, out_grad.shape[dim]).
//   out_grad gradient of the forward output.
//   dim      axis that was indexed; negative values count from the end.
//   x_grad   output gradient; allocated and fully overwritten.
// Raises InvalidArgument (via PADDLE_ENFORCE_EQ) when index's dtype is
// neither int32 nor int64.
template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
                           const DenseTensor& x,
                           const DenseTensor& index,
                           const DenseTensor& out_grad,
                           int dim,
                           DenseTensor* x_grad) {
  auto* output_grad_data = out_grad.data<T>();
  auto* in_grad_data = ctx.template Alloc<T>(x_grad);

  auto input_dim = x_grad->dims();
  auto output_dim = out_grad.dims();
  dim = dim >= 0 ? dim : dim + input_dim.size();
  auto stride_dim = phi::stride(input_dim);
  int64_t stride = stride_dim[dim];
  int64_t size = output_dim[dim];
  int64_t delta = input_dim[dim] - size;
  const auto& index_type = index.dtype();

  bool index_type_match =
      index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
  PADDLE_ENFORCE_EQ(index_type_match,
                    true,
                    phi::errors::InvalidArgument(
                        "Input(Index) holds the wrong type, it holds %s, but "
                        "desires to be %s or %s",
                        index_type,
                        phi::DataType::INT32,
                        phi::DataType::INT64));

  int64_t numel = x_grad->numel();
  if (numel == 0) {
    return;
  }

  // The kernel only accumulates, so x_grad must start at zero. Positions
  // never selected by `index` keep this zero gradient.
  phi::funcs::SetConstant<phi::GPUContext, T> index_select_grad_init;
  index_select_grad_init(ctx, x_grad, static_cast<T>(0));

  // The kernel walks over out_grad elements. With an empty out_grad there is
  // nothing to scatter, and a zero-sized grid would be an invalid launch.
  int64_t out_nums = out_grad.numel();
  if (out_nums == 0) {
    return;
  }

  auto stream = ctx.stream();

  // Size the grid by the kernel's iteration count (out_grad elements), not by
  // x_grad's element count.
  unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
  dim3 grid_dim = dim3((out_nums + block_dim - 1) / block_dim);
  phi::backends::gpu::LimitGridDim(ctx, &grid_dim);

  if (FLAGS_cudnn_deterministic) {
    // A single thread performs every atomic add, so the accumulation order
    // (and therefore floating-point rounding) is reproducible.
    VLOG(2) << "Run grad kernel of index_select with single thread.";
    block_dim = 1;
    grid_dim.x = 1;
  }

  if (index_type == phi::DataType::INT64) {
    const int64_t* index_data = index.data<int64_t>();
    index_select_grad_cuda_kernel<T, int64_t>
        <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
                                             in_grad_data,
                                             index_data,
                                             out_nums,
                                             stride,
                                             size,
                                             delta);
  } else {
    const int* index_data = index.data<int>();
    index_select_grad_cuda_kernel<T, int>
        <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
                                             in_grad_data,
                                             index_data,
                                             out_nums,
                                             stride,
                                             size,
                                             delta);
  }
}

}  // namespace phi

// Registers phi::IndexSelectGradKernel as the GPU implementation of
// index_select_grad for every layout and for the element types listed below.
PD_REGISTER_KERNEL(index_select_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::IndexSelectGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   int,
                   int64_t) {}