gather.cu.h 11.3 KB
Newer Older
1
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Z
zchen0211 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Z
zchen0211 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Z
zchen0211 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Z
zchen0211 已提交
14 15

#pragma once
16

17
#include <vector>
18

19 20
#include "paddle/fluid/memory/memcpy.h"
// TODO(paddle-dev): move gpu_primitives.h to phi
21
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
22 23 24
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
25
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zchen0211 已提交
26

27 28
namespace phi {
namespace funcs {
Z
zchen0211 已提交
29

30
/**
 * Row gather along the first dimension.
 *
 * `params` is treated as consecutive rows of `slice_size` elements. Output
 * position `out_pos` lies in row out_pos / slice_size; that row is copied from
 * row indices[out_pos / slice_size] of `params`. Indices are NOT
 * bounds-checked in this kernel.
 *
 * Work over all index_size * slice_size output elements is distributed by
 * CUDA_KERNEL_LOOP_TYPE using an int64_t counter.
 */
template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params,
                                 const IndexT* indices,
                                 T* output,
                                 size_t index_size,
                                 size_t slice_size) {
  CUDA_KERNEL_LOOP_TYPE(out_pos, index_size * slice_size, int64_t) {
    const int64_t row = out_pos / slice_size;
    const int64_t col = out_pos - row * slice_size;  // position inside the row
    const IndexT src_row = indices[row];
    const int64_t src_pos = src_row * slice_size + col;
    output[out_pos] = params[src_pos];
  }
}

45
/**
 * Gathers slices of `input` addressed by multi-dimensional index tuples.
 *
 * `indices` holds `remain_size` tuples of `end_size` consecutive entries.
 * Each tuple addresses the leading `end_size` dimensions of `input` (whose
 * extents are given by `input_dims`). The trailing dimensions form a
 * contiguous slice of `slice_size` elements that is copied into `output`.
 *
 * Work over all remain_size * slice_size output elements is distributed by
 * CUDA_KERNEL_LOOP_TYPE using an int64_t counter.
 *
 * Every index component must lie in [0, input_dims[j]); otherwise
 * PADDLE_ENFORCE reports the bound and the offending value.
 */
template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input,
                                   const Dim<DDim::kMaxRank> input_dims,
                                   const IndexT* indices,
                                   T* output,
                                   size_t remain_size,
                                   size_t slice_size,
                                   size_t end_size) {
  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    int64_t gather_i = 0;
    // Row-major stride of index component j, starting from the slice size
    // for the innermost indexed dimension.
    int64_t temp = slice_size;
    for (int64_t j = end_size - 1; j >= 0; --j) {
      auto index_value = indices[indices_i * end_size + j];
      // The format is printf-style; both values may be 64-bit, so pass them
      // as long long with %lld instead of mismatching them with %d.
      PADDLE_ENFORCE(
          index_value >= 0 && index_value < input_dims[j],
          "The index is out of bounds, "
          "please check whether the dimensions of index and "
          "input meet the requirements. It should "
          "be less than [%lld] and greater than or equal to 0, but received "
          "[%lld]",
          static_cast<long long>(input_dims[j]),
          static_cast<long long>(index_value));
      gather_i += (index_value * temp);
      temp *= input_dims[j];
    }
    int64_t input_i = gather_i + slice_i;
    *(output + i) = *(input + input_i);
  }
}

Z
zchen0211 已提交
76 77 78 79
/**
 * A thin wrapper on gpu tensors that gathers rows of `src` along its first
 * dimension:  output[k, ...] = src[index[k], ...].
 *
 * @param ctx     GPU context whose stream the kernel is launched on.
 * @param src     type-T source tensor.
 * @param index   type-IndexT index tensor: rank 1, or rank 2 with shape [N, 1].
 * @param output  destination tensor. This function neither resizes nor
 *                allocates it; it writes index_size * slice_size elements
 *                through output->data<T>(), where slice_size is the product
 *                of src's dimensions after the first.
 *
 * Returns without launching anything when there is nothing to copy.
 */
template <typename T, typename IndexT = int>
void GPUGather(const phi::GPUContext& ctx,
               const DenseTensor& src,
               const DenseTensor& index,
               DenseTensor* output) {
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(
        index.dims()[1],
        1,
        phi::errors::InvalidArgument("If the index's rank of gather_op is 2,"
                                     " the second dimension should be 1."));
  }

  // index size
  int64_t index_size = index.dims()[0];
  if (index_size == 0) return;

  // slice size: number of elements in one row of src
  auto src_dims = src.dims();
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  int64_t n = slice_size * index_size;
  // A zero-sized trailing dimension leaves nothing to copy, and launching a
  // kernel with a zero-sized grid is an invalid configuration.
  if (n == 0) return;

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  int block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  phi::backends::gpu::LimitGridDim(ctx, &grid);

  GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_src, p_index, p_output, index_size, slice_size);
}

121 122 123 124 125
/**
 * Multi-dimensional gather (gather_nd) on the GPU.
 *
 * The last dimension of `index` (end_size) is the length of each index
 * tuple. The remaining leading dimensions of `index` enumerate the tuples.
 * Each tuple selects a slice of `input` made of the dimensions after the
 * first end_size ones, and the slices are written to `output` in tuple
 * order.
 *
 * `output` is neither resized nor allocated here; it is written through
 * output->data<T>(). Returns without launching anything when the result
 * has no elements.
 */
template <typename T, typename IndexT = int>
void GPUGatherNd(const phi::GPUContext& ctx,
                 const DenseTensor& input,
                 const DenseTensor& index,
                 DenseTensor* output) {
  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();
  auto input_dims = input.dims();
  auto input_dims_size = input_dims.size();

  // final dim: length of each index tuple
  int64_t end_size = index_dims[index_dims_size - 1];
  // remain dim: number of index tuples
  auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = phi::product(remain_ddim);
  // slice size: elements copied per index tuple
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < input_dims_size; ++i) {
    slice_size *= input_dims[i];
  }

  int64_t n = slice_size * remain_numel;
  // Nothing to gather; a zero-sized grid would be an invalid launch config.
  if (n == 0) return;

  // source dims, passed to the kernel by value for bounds checks and strides
  Dim<DDim::kMaxRank> g_input_dims;
  for (int i = 0; i < input_dims_size; ++i) {
    g_input_dims[i] = input_dims[i];
  }

  const T* p_input = input.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  int block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  phi::backends::gpu::LimitGridDim(ctx, &grid);

  GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_input,
                                                                  g_input_dims,
                                                                  p_index,
                                                                  p_output,
                                                                  remain_numel,
                                                                  slice_size,
                                                                  end_size);
}

168
/**
 * Gathers along one axis of `input`.
 *
 * `outer_dim_size` is the element stride of the gather axis. `input` is
 * addressed as [*, input_index_dim_size, outer_dim_size] and `out` as
 * [*, out_index_dim_size, outer_dim_size] (row-major). Output position
 * (b, j, k) receives input(b, index[j], k).
 *
 * @param inner_dim_size  not read by this kernel; kept for the call signature.
 * @param size            total number of output elements.
 *
 * Each index value must lie in [0, input_index_dim_size); otherwise
 * PADDLE_ENFORCE reports the bound and the offending value.
 */
template <typename T, typename U>
__global__ void GatherGPUKernel(const T* input,
                                const U* index,
                                T* out,
                                int64_t outer_dim_size,
                                int64_t inner_dim_size,
                                int64_t out_index_dim_size,
                                int64_t input_index_dim_size,
                                int64_t size) {
  // Widen to 64 bits before multiplying: blockDim.x * blockIdx.x and
  // blockDim.x * gridDim.x are otherwise 32-bit unsigned products that can
  // wrap for very large launches.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  int64_t outer_size = outer_dim_size * out_index_dim_size;
  for (; idx < size; idx += stride) {
    int64_t inner_dim_index = idx / outer_size;
    int64_t next_idx = idx - outer_size * inner_dim_index;
    int64_t index_dim_index = next_idx / outer_dim_size;
    U index_val = index[index_dim_index];

    // The format is printf-style; both values may be 64-bit, so pass them
    // as long long with %lld instead of mismatching them with %d.
    PADDLE_ENFORCE(
        index_val >= 0 && index_val < input_index_dim_size,
        "The index is out of bounds, "
        "please check whether the dimensions of index and "
        "input meet the requirements. It should "
        "be less than [%lld] and greater than or equal to 0, but received "
        "[%lld]",
        static_cast<long long>(input_index_dim_size),
        static_cast<long long>(index_val));

    int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index;
    int64_t input_index =
        inner_dim_index * (outer_dim_size * input_index_dim_size) +
        index_val * outer_dim_size + out_dim_index;
    out[idx] = input[input_index];
  }
}

/**
 * Scatter-add counterpart of GatherGPUKernel, used for the gather gradient.
 *
 * `input` is addressed as [*, input_index_dim_size, outer_dim_size] and
 * `out` as [*, out_index_dim_size, outer_dim_size] (row-major). Element
 * (b, j, k) of `input` is atomically added into out(b, index[j], k), so
 * duplicate indices accumulate. `out` must be initialized by the caller
 * before launch.
 *
 * @param inner_dim_size  not read by this kernel; kept for the call signature.
 * @param size            total number of `input` elements.
 *
 * Each index value must lie in [0, out_index_dim_size); otherwise
 * PADDLE_ENFORCE reports the bound and the offending value.
 */
template <typename T, typename U>
__global__ void GatherGradGPUKernel(const T* input,
                                    const U* index,
                                    T* out,
                                    int64_t outer_dim_size,
                                    int64_t inner_dim_size,
                                    int64_t input_index_dim_size,
                                    int64_t out_index_dim_size,
                                    int64_t size) {
  // Widen to 64 bits before multiplying to avoid 32-bit unsigned wrap.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  for (; idx < size; idx += stride) {
    int64_t inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
    int64_t next_idx = idx % (outer_dim_size * input_index_dim_size);
    int64_t index_dim_index = next_idx / (outer_dim_size);
    int64_t out_dim_index = next_idx % outer_dim_size;
    U index_val = index[index_dim_index];

    // Same contract as the forward gather. Without this check an
    // out-of-range index makes the atomic add below write outside `out`.
    PADDLE_ENFORCE(
        index_val >= 0 && index_val < out_index_dim_size,
        "The index is out of bounds, "
        "please check whether the dimensions of index and "
        "input meet the requirements. It should "
        "be less than [%lld] and greater than or equal to 0, but received "
        "[%lld]",
        static_cast<long long>(out_index_dim_size),
        static_cast<long long>(index_val));

    int64_t out_index =
        inner_dim_index * (outer_dim_size * out_index_dim_size) +
        index_val * outer_dim_size + out_dim_index;
    paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx));
  }
}

224
/**
 * Gathers slices of `input` along `axis`, using every value of `index`
 * (index->numel() entries, read as a flat list).
 *
 * `out` is resized to input's shape with dimension `axis` replaced by
 * index->numel(), then allocated on `ctx`. A negative `axis` counts from
 * the last dimension. Returns without touching `out` when `input` is
 * empty, and without launching when the result is empty.
 */
template <typename T, typename U>
void GatherV2CUDAFunction(const DenseTensor* input,
                          const DenseTensor* index,
                          const int axis,
                          DenseTensor* out,
                          const phi::GPUContext& ctx) {
  int64_t index_size = index->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();
  auto* index_data = index->data<U>();

  if (input->numel() == 0) return;

  // Accept negative axes counted from the end (e.g. -1 = last dimension).
  int axis_index = axis < 0 ? axis + input_dim.size() : axis;
  int64_t index_dim_size = input_dim[axis_index];

  // inner_dim_size: product of dims before `axis`;
  // outer_dim_size: product of dims after `axis` (the stride of `axis`).
  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;
  std::vector<int64_t> out_dim_vec;
  out_dim_vec.reserve(input_dim.size());

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  out_dim_vec.push_back(index_size);
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  auto out_dim = phi::make_ddim(out_dim_vec);

  out->Resize(out_dim);
  auto* out_data = ctx.Alloc<T>(out);
  int64_t out_size = out->numel();
  if (out_size == 0) return;

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
  auto stream = ctx.stream();
  GatherGPUKernel<T, U>
      <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
          input_data,
          index_data,
          out_data,
          outer_dim_size,
          inner_dim_size,
          index_size,
          index_dim_size,
          out_size);
}

275
/**
 * Backward of GatherV2CUDAFunction: scatter-adds `input` (the output
 * gradient) into `out` (the gradient w.r.t. the gathered tensor) along
 * `axis`.
 *
 * `out` keeps its existing dims; it is allocated on `ctx` and zero-filled
 * before accumulation. Indices that repeat accumulate their contributions.
 * A negative `axis` counts from the last dimension.
 */
template <typename T, typename U>
void GatherV2GradCUDAFunction(const DenseTensor* input,
                              const DenseTensor* index,
                              const int axis,
                              DenseTensor* out,
                              const phi::GPUContext& ctx) {
  auto* index_data = index->data<U>();
  int64_t input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  // GatherGradGPUKernel accumulates with atomic adds, so `out` must start at
  // zero. Do this before the empty-input exit: an empty `input` (e.g. from an
  // empty index) means an all-zero gradient. Returning earlier would leave
  // `out` unallocated and uninitialized.
  auto* out_data = ctx.Alloc<T>(out);
  auto out_dim = out->dims();
  phi::funcs::set_constant(ctx, out, 0.0);
  if (input_size == 0) return;

  // Accept negative axes counted from the end (e.g. -1 = last dimension).
  int axis_index = axis < 0 ? axis + input_dim.size() : axis;
  int64_t input_index_dim_size = input_dim[axis_index];
  int64_t out_index_dim_size = out_dim[axis_index];

  // inner_dim_size: product of dims before `axis`;
  // outer_dim_size: product of dims after `axis` (the stride of `axis`).
  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;
  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
  }

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, input_size);
  auto stream = ctx.stream();
  GatherGradGPUKernel<T, U>
      <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
          input_data,
          index_data,
          out_data,
          outer_dim_size,
          inner_dim_size,
          input_index_dim_size,
          out_index_dim_size,
          input_size);
}
319 320 321

}  // namespace funcs
}  // namespace phi