/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/memory/memcpy.h"
// TODO(paddle-dev): move gpu_primitives.h to phi
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {
namespace funcs {

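/**
 * Gather kernel for axis-0 gather: params is viewed as a
 * [num_rows, slice_size] array and output as [index_size, slice_size];
 * thread i copies params[indices[i / slice_size]][i % slice_size].
 */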
template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params,
                                 const IndexT* indices,
                                 T* output,
                                 size_t index_size,
                                 size_t slice_size) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT gather_i = indices[indices_i];
    int64_t params_i = gather_i * slice_size + slice_i;
    *(output + i) = *(params + params_i);
  }
}

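/**
 * Gather-Nd kernel: indices holds remain_size tuples of end_size
 * coordinates into the leading dimensions of input. For each tuple, the
 * row-major element offset is accumulated from the last coordinate
 * backwards (temp carries the running stride), then a contiguous slice of
 * slice_size elements is copied to output.
 */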
template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input,
                                   const Dim<DDim::kMaxRank> input_dims,
                                   const IndexT* indices,
                                   T* output,
                                   size_t remain_size,
                                   size_t slice_size,
                                   size_t end_size) {
  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    int64_t gather_i = 0;
    int64_t temp = slice_size;
    for (int64_t j = end_size - 1; j >= 0; --j) {
      auto index_value = indices[indices_i * end_size + j];
      PADDLE_ENFORCE(
          index_value >= 0 && index_value < input_dims[j],
          "The index is out of bounds; "
          "please check whether the dimensions of index and "
          "input meet the requirements. It should "
          "be less than [%d] and greater than or equal to 0, but received [%d]",
          input_dims[j],
          index_value);
      gather_i += (index_value * temp);
      temp *= input_dims[j];
    }
    int64_t input_i = gather_i + slice_i;
    *(output + i) = *(input + input_i);
  }
}

/**
 * A thin wrapper on gpu tensor.
 * Returns a new tensor gathered from the source tensor according to index.
 * input[src]: type-T source Tensor
 * input[index]: type-IndexT index Tensor (1-D)
 * return: output tensor
 */
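// A minimal usage sketch (hypothetical shapes: `src` is a [4, 8] float32
// tensor and `index` a 1-D int32 tensor on the GPU). GPUGather reads
// output->data<T>(), so the caller resizes and allocates output first:
//
//   DenseTensor out;
//   out.Resize(phi::make_ddim({index.dims()[0], src.dims()[1]}));
//   ctx.Alloc<float>(&out);
//   phi::funcs::GPUGather<float, int>(ctx, src, index, &out);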
template <typename T, typename IndexT = int>
void GPUGather(const phi::GPUContext& ctx,
               const DenseTensor& src,
               const DenseTensor& index,
               DenseTensor* output) {
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(
        index.dims()[1],
        1,
        phi::errors::InvalidArgument("If the index's rank of gather_op is 2,"
                                     " the second dimension should be 1."));
  }

  // index size
  int64_t index_size = index.dims()[0];
  if (index_size == 0) return;

  auto src_dims = src.dims();
  phi::DDim output_dims(src_dims);
  output_dims[0] = index_size;

  // slice size
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  int block = 512;
  int64_t n = slice_size * index_size;
  dim3 grid = dim3((n + block - 1) / block);
  paddle::platform::LimitGridDim(ctx, &grid);

  GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_src, p_index, p_output, index_size, slice_size);
}

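/**
 * GPU gather_nd: the last dimension of index (end_size) indexes the
 * leading end_size dimensions of input, so the output shape is
 * index.dims()[:-1] ++ input.dims()[end_size:]. For example, input of
 * shape [4, 5, 6] with index of shape [3, 2] yields output of shape
 * [3, 6]. output must be pre-sized and allocated by the caller.
 */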
template <typename T, typename IndexT = int>
void GPUGatherNd(const phi::GPUContext& ctx,
                 const DenseTensor& input,
                 const DenseTensor& index,
                 DenseTensor* output) {
  const auto gplace = ctx.GetPlace();
  auto cplace = phi::CPUPlace();

  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();
  auto input_dims = input.dims();
  auto input_dims_size = input_dims.size();

  const T* p_input = input.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // final dim
  int64_t end_size = index_dims[index_dims_size - 1];
  // remain dim
  auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = phi::product(remain_ddim);
  // slice size
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < input_dims_size; ++i) {
    slice_size *= input_dims[i];
  }
  // source dim
  Dim<DDim::kMaxRank> g_input_dims;
  for (int i = 0; i < input_dims_size; ++i) {
    g_input_dims[i] = input_dims[i];
  }

  int block = 512;
  int64_t n = slice_size * remain_numel;
  dim3 grid = dim3((n + block - 1) / block);
  paddle::platform::LimitGridDim(ctx, &grid);

  GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_input,
                                                                  g_input_dims,
                                                                  p_index,
                                                                  p_output,
                                                                  remain_numel,
                                                                  slice_size,
                                                                  end_size);
}

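/**
 * Gather along an arbitrary axis. The input is viewed as
 * [inner_dim_size, input_index_dim_size, outer_dim_size], where
 * inner_dim_size is the product of the dimensions before the axis and
 * outer_dim_size the product of those after it; each output element keeps
 * its (inner, outer) coordinates and takes its axis coordinate from index.
 */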
template <typename T, typename U>
__global__ void GatherGPUKernel(const T* input,
                                const U* index,
                                T* out,
                                int64_t outer_dim_size,
                                int64_t inner_dim_size,
                                int64_t out_index_dim_size,
                                int64_t input_index_dim_size,
                                int64_t size) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  int64_t outer_size = outer_dim_size * out_index_dim_size;
  for (; idx < size; idx += blockDim.x * gridDim.x) {
    int64_t inner_dim_index = idx / outer_size;
    int64_t next_idx = idx - outer_size * inner_dim_index;
    int64_t index_dim_index = next_idx / outer_dim_size;
    U index_val = index[index_dim_index];

    PADDLE_ENFORCE(
        index_val >= 0 && index_val < input_index_dim_size,
        "The index is out of bounds; "
        "please check whether the dimensions of index and "
        "input meet the requirements. It should "
        "be less than [%d] and greater than or equal to 0, but received [%d]",
        input_index_dim_size,
        index_val);

    int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index;
    int64_t input_index =
        inner_dim_index * (outer_dim_size * input_index_dim_size) +
        index_val * outer_dim_size + out_dim_index;
    out[idx] = input[input_index];
  }
}

template <typename T, typename U>
__global__ void GatherGradGPUKernel(const T* input,
                                    const U* index,
                                    T* out,
                                    int64_t outer_dim_size,
                                    int64_t inner_dim_size,
                                    int64_t input_index_dim_size,
                                    int64_t out_index_dim_size,
                                    int64_t size) {
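  // Inverse of GatherGPUKernel: element idx of the upstream gradient
  // ("input") is accumulated into the position of "out" it was gathered
  // from. CudaAtomicAdd is required because duplicate index values map
  // several gradient elements onto the same out_index.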
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (; idx < size; idx += blockDim.x * gridDim.x) {
    int64_t inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
    int64_t next_idx = idx % (outer_dim_size * input_index_dim_size);
    int64_t index_dim_index = next_idx / (outer_dim_size);
    int64_t out_dim_index = next_idx % outer_dim_size;
    int64_t out_index =
        inner_dim_index * (outer_dim_size * out_index_dim_size) +
        index[index_dim_index] * outer_dim_size + out_dim_index;
    paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx));
  }
}

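/**
 * Functor for gather along an arbitrary axis: resizes out to input's
 * shape with dims()[axis] replaced by index->numel(), allocates it, and
 * launches GatherGPUKernel.
 */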
template <typename T, typename U>
void GatherV2CUDAFunction(const DenseTensor* input,
                          const DenseTensor* index,
                          const int axis,
                          DenseTensor* out,
                          const phi::GPUContext& ctx) {
  int64_t index_size = index->numel();
  int64_t input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();
  auto* index_data = index->data<U>();

  if (input->numel() == 0) return;

  int axis_index = axis;
  int64_t index_dim_size = input_dim[axis_index];

  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;
  std::vector<int64_t> out_dim_vec;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  out_dim_vec.push_back(index_size);
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  auto out_dim = phi::make_ddim(out_dim_vec);

  out->Resize(out_dim);
  auto* out_data = ctx.Alloc<T>(out);
  int64_t out_size = out->numel();
  if (out_size == 0) return;

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
  auto stream = ctx.stream();
  GatherGPUKernel<T, U>
      <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
          input_data,
          index_data,
          out_data,
          outer_dim_size,
          inner_dim_size,
          index_size,
          index_dim_size,
          out_size);
}

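/**
 * Gradient of gather along an axis: zero-fills out (the gradient w.r.t.
 * the forward input, already resized by the caller), then scatter-adds
 * each element of input (the upstream gradient) back through index via
 * GatherGradGPUKernel.
 */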
template <typename T, typename U>
void GatherV2GradCUDAFunction(const DenseTensor* input,
                              const DenseTensor* index,
                              const int axis,
                              DenseTensor* out,
                              const phi::GPUContext& ctx) {
  auto* index_data = index->data<U>();
  int64_t index_size = index->numel();
  int64_t input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  if (input->numel() == 0) return;
  int axis_index = axis;
  int64_t input_index_dim_size = input_dim[axis_index];

  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
  }

  auto* out_data = ctx.Alloc<T>(out);
  auto out_dim = out->dims();
  int64_t out_index_dim_size = out_dim[axis_index];
  phi::funcs::set_constant(ctx, out, 0.0);

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, input_size);
  auto stream = ctx.stream();
  GatherGradGPUKernel<T, U>
      <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
          input_data,
          index_data,
          out_data,
          outer_dim_size,
          inner_dim_size,
          input_index_dim_size,
          out_index_dim_size,
          input_size);
}

}  // namespace funcs
}  // namespace phi