/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Z
zchen0211 已提交
14 15

#pragma once
16 17 18
#include <vector>
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/operator.h"
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/tensor.h"
20
#include "paddle/fluid/memory/malloc.h"
21
#include "paddle/fluid/operators/math/math_function.h"
22
#include "paddle/fluid/platform/cuda_primitives.h"
23
#include "paddle/fluid/platform/gpu_launch_config.h"
Y
Yi Wang 已提交
24
#include "paddle/fluid/platform/place.h"
Z
zchen0211 已提交
25 26 27 28
namespace paddle {
namespace operators {

using framework::Tensor;
Q
QI JUN 已提交
29
using platform::DeviceContext;
Z
zchen0211 已提交
30

31 32 33 34
/**
 * Copies whole rows of `params` selected by `indices`.
 *
 * `params` is read as consecutive rows of `slice_size` elements. Output row k
 * (k in [0, index_size)) receives a copy of params row indices[k]. Each
 * thread copies single elements and strides over all
 * index_size * slice_size output elements by the total thread count, so any
 * 1-D launch configuration covers the whole output.
 *
 * Every offset is computed in 64-bit arithmetic. With 32-bit ints, the flat
 * output index and the source offset (indices[k] * slice_size + col)
 * overflow once a tensor holds more than 2^31 elements. Index values are
 * not range-checked here.
 */
template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
                                 T* output, size_t index_size,
                                 size_t slice_size) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t total = static_cast<int64_t>(index_size * slice_size);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;           // which output row
    const int64_t slice_i = i - indices_i * slice;  // offset inside the row
    const int64_t gather_i = static_cast<int64_t>(indices[indices_i]);
    output[i] = params[gather_i * slice + slice_i];
  }
}

44 45 46 47 48
/**
 * Gather-nd kernel: each index tuple addresses a contiguous slice of `input`.
 *
 * `indices` holds remain_size tuples of end_size values each. A tuple is
 * turned into a flat offset into `input`. Tuple element j is scaled by
 * slice_size times the product of input_dims[j+1 .. end_size-1]. Then
 * slice_size consecutive elements starting at that offset are copied to
 * output[tuple * slice_size .. ]. `input_dims` must be a device-readable
 * array of at least end_size ints.
 *
 * When assert is enabled, an index value outside [0, input_dims[j]) traps the
 * kernel. Offsets are accumulated in 64-bit arithmetic so large tensors do
 * not overflow a 32-bit IndexT or int.
 */
template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input, const int* input_dims,
                                   const IndexT* indices, T* output,
                                   size_t remain_size, size_t slice_size,
                                   size_t end_size) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t end = static_cast<int64_t>(end_size);
  const int64_t total = static_cast<int64_t>(remain_size * slice_size);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;           // which index tuple
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    int64_t gather_i = 0;
    int64_t temp = slice;  // stride of dimension j, in elements
    for (int64_t j = end - 1; j >= 0; --j) {
      auto index_value = indices[indices_i * end + j];
      assert(index_value >= 0 && index_value < input_dims[j]);
      gather_i += static_cast<int64_t>(index_value) * temp;
      temp *= input_dims[j];
    }
    output[i] = input[gather_i + slice_i];
  }
}

Z
zchen0211 已提交
65 66 67 68
/**
 * Gathers rows of `src` along dimension 0 on the GPU:
 *   output[k, ...] = src[index[k], ...]   for k in [0, index.dims()[0]).
 *
 * @param ctx     device context; it is cast to CUDADeviceContext and its
 *                stream is used for the launch, so it must be one.
 * @param src     type-T source tensor.
 * @param index   type-IndexT index tensor: 1-D of shape [N] with N > 0, or
 *                2-D of shape [N, 1].
 * @param output  destination; its buffer must already be allocated
 *                (data<T>() is used, not mutable_data) with room for
 *                N * (product of src dims 1..rank-1) elements. Presumably the
 *                caller resizes it to [N, src.dims()[1:]] — TODO confirm.
 *
 * The kernel is launched asynchronously on ctx's stream. Nothing is launched
 * when there are no elements to copy.
 */
template <typename T, typename IndexT = int>
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
               const Tensor& index, Tensor* output) {
  // The index must be 1-D (non-empty), or 2-D with a trailing dimension of 1.
  if (index.dims().size() == 1) {
    PADDLE_ENFORCE_GT(index.dims()[0], 0,
                      platform::errors::InvalidArgument(
                          "The index of gather_op should not be empty "
                          "when the index's rank is 1."));
  } else if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
                      platform::errors::InvalidArgument(
                          "If the index's rank of gather_op is 2,"
                          " the second dimension should be 1."));
  }

  const int64_t index_size = index.dims()[0];

  // Number of elements in one row of src (product of dims 1..rank-1).
  const auto src_dims = src.dims();
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  // Nothing to copy. This also avoids launching with a zero-sized grid, which
  // CUDA rejects as an invalid configuration.
  const int64_t n = slice_size * index_size;
  if (n == 0) return;

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  constexpr int kBlock = 512;
  const int grid = static_cast<int>((n + kBlock - 1) / kBlock);

  // static_cast is the well-defined base-to-derived reference conversion;
  // reinterpret_cast is not the right tool for a class hierarchy.
  const auto& cuda_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
  GatherCUDAKernel<T, IndexT><<<grid, kBlock, 0, cuda_ctx.stream()>>>(
      p_src, p_index, p_output, index_size, slice_size);
}

112 113 114 115
/**
 * Gather-nd on the GPU. The last dimension of `index` (end_size) holds
 * coordinates into the first end_size dimensions of `input`. Each index
 * tuple selects the contiguous slice made of the remaining input dimensions.
 *
 * @param context  execution context; its DeviceContext (a CUDA context, since
 *                 its place is read as CUDAPlace) supplies the stream.
 * @param input    type-T source tensor.
 * @param index    type-IndexT tensor of shape [..., end_size].
 * @param output   destination; its buffer must already be allocated
 *                 (data<T>() is used) with room for
 *                 product(index.dims()[:-1]) * slice_size elements.
 *
 * input's dims are staged into a temporary device buffer for the kernel. The
 * copy and the launch are enqueued on the context's stream. Nothing is
 * launched when there are no elements to copy.
 */
template <typename DeviceContext, typename T, typename IndexT = int>
void GPUGatherNd(const framework::ExecutionContext& context,
                 const Tensor& input, const Tensor& index, Tensor* output) {
  const auto& ctx = context.template device_context<DeviceContext>();
  const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
  auto cplace = platform::CPUPlace();

  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();
  auto input_dims = input.dims();
  auto input_dims_size = input_dims.size();

  // final dim: number of coordinates per index tuple
  int64_t end_size = index_dims[index_dims_size - 1];
  // remain dims: number of index tuples
  auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = framework::product(remain_ddim);
  // slice size: elements in the input dims not addressed by the index
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < input_dims_size; ++i) {
    slice_size *= input_dims[i];
  }

  // Nothing to copy. This also avoids a zero-sized (invalid) launch and the
  // pointless device allocation/copy.
  const int64_t n = slice_size * remain_numel;
  if (n == 0) return;

  const T* p_input = input.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // source dims, narrowed to int for the kernel
  std::vector<int> v_input_dims(input_dims_size);
  for (int i = 0; i < input_dims_size; ++i) {
    v_input_dims[i] = static_cast<int>(input_dims[i]);
  }

  int bytes = input_dims_size * sizeof(int);
  auto p_input_dims = memory::Alloc(ctx, bytes);
  int* g_input_dims = reinterpret_cast<int*>(p_input_dims->ptr());
  memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes,
               ctx.stream());

  constexpr int kBlock = 512;
  const int grid = static_cast<int>((n + kBlock - 1) / kBlock);

  GatherNdCUDAKernel<T, IndexT><<<grid, kBlock, 0, ctx.stream()>>>(
      p_input, g_input_dims, p_index, p_output, remain_numel, slice_size,
      end_size);
}

162 163 164 165 166 167
/**
 * Forward kernel for gathering along an arbitrary axis.
 *
 * Flat output element idx is decomposed as
 *   idx = (inner * out_index_dim_size + k) * outer_dim_size + col
 * and copied from
 *   input[(inner * input_index_dim_size + index[k]) * outer_dim_size + col].
 * GatherV2CUDAFunction passes the product of the dims after the axis as
 * `outer_dim_size` and the product of the dims before it as
 * `inner_dim_size`. `inner_dim_size` is not read by this kernel.
 *
 * Threads stride over [0, size) by the total thread count. Offsets and the
 * index value are kept in 64-bit arithmetic, so neither large tensors nor
 * 64-bit U index values are truncated to int.
 */
template <typename T, typename U>
__global__ void GatherGPUKernel(const T* input, const U* index, T* out,
                                int outer_dim_size, int inner_dim_size,
                                int out_index_dim_size,
                                int input_index_dim_size, int size) {
  const int64_t outer = outer_dim_size;
  // Elements spanned by one step of the leading (pre-axis) coordinate.
  const int64_t out_block = outer * out_index_dim_size;
  const int64_t in_block = outer * input_index_dim_size;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x +
                     threadIdx.x;
       idx < size; idx += stride) {
    const int64_t inner_dim_index = idx / out_block;
    const int64_t next_idx = idx - out_block * inner_dim_index;
    const int64_t index_dim_index = next_idx / outer;
    const int64_t index_val = static_cast<int64_t>(index[index_dim_index]);
    const int64_t out_dim_index = next_idx - outer * index_dim_index;
    const int64_t input_index =
        inner_dim_index * in_block + index_val * outer + out_dim_index;
    out[idx] = input[input_index];
  }
}

/**
 * Backward kernel for gathering along an arbitrary axis: scatter-adds the
 * upstream gradient `input` into `out`.
 *
 * Flat input element idx is decomposed as
 *   idx = (inner * input_index_dim_size + k) * outer_dim_size + col
 * and atomically added to
 *   out[(inner * out_index_dim_size + index[k]) * outer_dim_size + col].
 * Duplicate index values accumulate, so `out` must be zeroed before launch.
 * `inner_dim_size` is not read by this kernel.
 *
 * Threads stride over [0, size) by the total thread count. Offsets are kept
 * in 64-bit arithmetic so the output offset cannot overflow int.
 */
template <typename T, typename U>
__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
                                    int outer_dim_size, int inner_dim_size,
                                    int input_index_dim_size,
                                    int out_index_dim_size, int size) {
  const int64_t outer = outer_dim_size;
  // Elements spanned by one step of the leading (pre-axis) coordinate.
  const int64_t in_block = outer * input_index_dim_size;
  const int64_t out_block = outer * out_index_dim_size;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x +
                     threadIdx.x;
       idx < size; idx += stride) {
    const int64_t inner_dim_index = idx / in_block;
    const int64_t next_idx = idx % in_block;
    const int64_t index_dim_index = next_idx / outer;
    const int64_t out_dim_index = next_idx % outer;
    const int64_t out_index =
        inner_dim_index * out_block +
        static_cast<int64_t>(index[index_dim_index]) * outer + out_dim_index;
    paddle::platform::CudaAtomicAdd(out + out_index, input[idx]);
  }
}

/**
 * Gathers entries of `input` along a runtime-selected axis:
 *   out[i0, .., k, .., iN] = input[i0, .., index[k], .., iN]
 * where k runs over the flattened `index` and sits at position `axis`.
 *
 * @param input  type-T source tensor.
 * @param index  type-U index tensor; all of its elements are used, flattened.
 *               Values are not range-checked.
 * @param axis   one-element type-V tensor holding the axis. It is copied to
 *               the host with TensorCopySync. A negative value counts from the
 *               last dimension.
 * @param out    resized to the gathered shape and allocated on `place`.
 * @param place  place on which `out` is allocated.
 * @param ctx    execution context whose CUDA stream runs the kernel.
 *
 * Returns without touching `out` when `input` is empty. Enforces that `axis`
 * has exactly one element and lies in [-rank, rank).
 */
template <typename T, typename U, typename V>
void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
                          const Tensor* axis, Tensor* out,
                          const paddle::platform::Place& place,
                          const framework::ExecutionContext& ctx) {
  if (input->numel() == 0) return;

  const int axis_size = axis->numel();
  const int index_size = index->numel();
  const auto input_dim = input->dims();
  PADDLE_ENFORCE_EQ(axis_size, 1,
                    platform::errors::InvalidArgument(
                        "Axis size should be 1, but received %d", axis_size));

  // The axis is needed on the host. TensorCopySync returns only after the
  // copy has completed, so reading the value right away is safe.
  Tensor cpu_axis;
  framework::TensorCopySync(*axis, platform::CPUPlace(), &cpu_axis);
  const int rank = input_dim.size();
  const int raw_axis = static_cast<int>(cpu_axis.data<V>()[0]);
  const int axis_index = raw_axis < 0 ? raw_axis + rank : raw_axis;
  PADDLE_ENFORCE_EQ(
      axis_index >= 0 && axis_index < rank, true,
      platform::errors::InvalidArgument(
          "The axis of gather should be in range [%d, %d), but received %d.",
          -rank, rank, raw_axis));
  const int index_dim_size = input_dim[axis_index];

  // inner: product of dims before the axis; outer: product of dims after it.
  int inner_dim_size = 1;
  int outer_dim_size = 1;
  std::vector<int> out_dim_vec;
  out_dim_vec.reserve(rank);
  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  out_dim_vec.push_back(index_size);
  for (int i = axis_index + 1; i < rank; i++) {
    outer_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }

  out->Resize(framework::make_ddim(out_dim_vec));
  auto* out_data = out->mutable_data<T>(place);
  const int out_size = out->numel();
  // An empty index yields an empty output. Return before building a launch
  // config for zero elements, which would be an invalid launch.
  if (out_size == 0) return;

  const auto* input_data = input->data<T>();
  const auto* index_data = index->data<U>();

  const auto& dev_ctx = ctx.cuda_device_context();
  platform::GpuLaunchConfig config =
      platform::GetGpuLaunchConfig1D(dev_ctx, out_size);
  GatherGPUKernel<T, U><<<config.block_per_grid, config.thread_per_block, 0,
                          dev_ctx.stream()>>>(
      input_data, index_data, out_data, outer_dim_size, inner_dim_size,
      index_size, index_dim_size, out_size);
}

/**
 * Backward of GatherV2CUDAFunction: scatter-adds the upstream gradient
 * `input` into `out` along `axis`, using the same flattened `index`.
 *
 * @param input  type-T upstream gradient (shape of the forward output).
 * @param index  type-U index tensor used in the forward pass.
 * @param axis   one-element type-V tensor holding the axis. It is copied to
 *               the host with TensorCopySync. A negative value counts from the
 *               last dimension.
 * @param out    gradient w.r.t. the forward input. Its dims must already be
 *               set by the caller. It is allocated on `place` and zeroed.
 * @param place  place on which `out` is allocated.
 * @param ctx    execution context whose CUDA stream runs both the zero-fill
 *               and the kernel.
 *
 * `out` is always allocated and zero-filled, even when `input` is empty, so an
 * empty upstream gradient still produces a valid all-zero gradient.
 */
template <typename T, typename U, typename V>
void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
                              const Tensor* axis, Tensor* out,
                              const paddle::platform::Place& place,
                              const framework::ExecutionContext& ctx) {
  const auto& dev_ctx = ctx.cuda_device_context();

  // The kernel accumulates with atomic adds, so out must start at zero. The
  // fill is issued on the same stream as the kernel so the two are ordered.
  auto* out_data = out->mutable_data<T>(place);
  operators::math::set_constant(dev_ctx, out, 0.0);

  const int input_size = input->numel();
  if (input_size == 0) return;

  const int axis_size = axis->numel();
  PADDLE_ENFORCE_EQ(axis_size, 1,
                    platform::errors::InvalidArgument(
                        "Axis size should be 1, but received %d", axis_size));

  // The axis is needed on the host. TensorCopySync returns only after the
  // copy has completed, so reading the value right away is safe.
  Tensor cpu_axis;
  framework::TensorCopySync(*axis, platform::CPUPlace(), &cpu_axis);
  const auto input_dim = input->dims();
  const int rank = input_dim.size();
  const int raw_axis = static_cast<int>(cpu_axis.data<V>()[0]);
  const int axis_index = raw_axis < 0 ? raw_axis + rank : raw_axis;
  PADDLE_ENFORCE_EQ(
      axis_index >= 0 && axis_index < rank, true,
      platform::errors::InvalidArgument(
          "The axis of gather should be in range [%d, %d), but received %d.",
          -rank, rank, raw_axis));
  const int input_index_dim_size = input_dim[axis_index];
  const int out_index_dim_size = out->dims()[axis_index];

  // inner: product of dims before the axis; outer: product of dims after it.
  int inner_dim_size = 1;
  int outer_dim_size = 1;
  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < rank; i++) {
    outer_dim_size *= input_dim[i];
  }

  const auto* input_data = input->data<T>();
  const auto* index_data = index->data<U>();

  platform::GpuLaunchConfig config =
      platform::GetGpuLaunchConfig1D(dev_ctx, input_size);
  GatherGradGPUKernel<T, U><<<config.block_per_grid, config.thread_per_block,
                              0, dev_ctx.stream()>>>(
      input_data, index_data, out_data, outer_dim_size, inner_dim_size,
      input_index_dim_size, out_index_dim_size, input_size);
}
Z
zchen0211 已提交
294 295
}  // namespace operators
}  // namespace paddle