/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using platform::DeviceContext;

// Grid-stride loop over [0, n) with a 32-bit `int` loop variable.
// NOTE(review): kept byte-for-byte because other headers/sources may define or
// use the same macro (a redefinition with different tokens would conflict).
// The kernels in this file no longer use it: `int i` overflows once n exceeds
// INT_MAX, so they carry their own 64-bit loop instead.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

/**
 * Gathers whole slices of `params` selected by `indices`:
 *   output[k * slice_size + s] = params[indices[k] * slice_size + s]
 * for k in [0, index_size) and s in [0, slice_size).
 *
 * Launch: 1-D grid and block of any size; the body is a grid-stride loop, so
 * every output element is visited whatever the launch dimensions. Uses no
 * shared memory. All offsets are computed in int64_t so tensors with more than
 * INT_MAX elements are addressed correctly.
 * Precondition (not checked here): every indices[k] selects an existing slice
 * of `params`.
 */
template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
                                 T* output, size_t index_size,
                                 size_t slice_size) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t total = static_cast<int64_t>(index_size) * slice;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    const int64_t gather_i = static_cast<int64_t>(indices[indices_i]);
    output[i] = params[gather_i * slice + slice_i];
  }
}

/**
 * N-d gather kernel. `indices` holds `remain_size` index tuples of length
 * `end_size`, stored row-major. For each tuple, the `slice_size` contiguous
 * elements of `input` addressed by that tuple are copied into `output`.
 * `input_dims` must be device-readable and hold the extents of `input`. It is
 * used to linearize each tuple and, when assertions are enabled, to
 * bounds-check every coordinate.
 *
 * Launch: 1-D grid and block of any size (grid-stride loop). Uses no shared
 * memory. All offsets are computed in int64_t.
 */
template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input, const int* input_dims,
                                   const IndexT* indices, T* output,
                                   size_t remain_size, size_t slice_size,
                                   size_t end_size) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t end = static_cast<int64_t>(end_size);
  const int64_t total = static_cast<int64_t>(remain_size) * slice;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    // Linearize the tuple from its last coordinate backwards. Coordinate j is
    // weighted by slice_size times the extents of dims j+1 .. end_size-1.
    int64_t gather_i = 0;
    int64_t temp = slice;
    for (int64_t j = end - 1; j >= 0; --j) {
      const auto index_value = indices[indices_i * end + j];
      assert(index_value >= 0 && index_value < input_dims[j]);
      gather_i += static_cast<int64_t>(index_value) * temp;
      temp *= input_dims[j];
    }
    output[i] = input[gather_i + slice_i];
  }
}

// Number of blocks for a 1-D launch over `n` (> 0) elements with `block`
// threads per block. The result is clamped to the CUDA grid x-dimension limit
// (2^31 - 1). That is safe because both gather kernels use grid-stride loops,
// so a clamped grid still covers every element.
inline int GetGatherGridSize(int64_t n, int block) {
  const int64_t grid = (n + block - 1) / block;
  return static_cast<int>(
      std::min<int64_t>(grid, std::numeric_limits<int>::max()));
}

/**
 * A thin wrapper on gpu tensors: gathers slices of `src` along dimension 0.
 *   output[k, ...] = src[index[k], ...]
 *
 * input[src]: type-T source Tensor.
 * input[index]: type-IndexT index Tensor, either 1-D (non-empty) or 2-D with
 *   a second dimension of 1.
 * output: written through output->data<T>(). It is not resized or allocated
 *   here, so the caller must already have allocated it with shape
 *   [index.dims()[0], src.dims()[1], ...].
 *
 * `ctx` must be a platform::CUDADeviceContext. The kernel is enqueued
 * asynchronously on its stream. Nothing is launched when the output is empty.
 */
template <typename T, typename IndexT = int>
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
               const Tensor& index, Tensor* output) {
  // Accept a 1-D index, or a 2-D index whose second dimension is 1.
  if (index.dims().size() == 1) {
    PADDLE_ENFORCE_GT(index.dims()[0], 0,
                      platform::errors::InvalidArgument(
                          "The index of gather_op should not be empty "
                          "when the index's rank is 1."));
  } else if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
                      platform::errors::InvalidArgument(
                          "If the index's rank of gather_op is 2,"
                          " the second dimension should be 1."));
  }

  const int64_t index_size = index.dims()[0];
  const auto src_dims = src.dims();

  // Number of elements in one slice (all dims of src except the first).
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  const int64_t n = index_size * slice_size;
  // A launch with zero blocks is an invalid configuration; there is also
  // nothing to copy.
  if (n == 0) return;

  const int block = 512;
  const int grid = GetGatherGridSize(n, block);

  GatherCUDAKernel<T, IndexT><<<
      grid, block, 0,
      static_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
      p_src, p_index, p_output, index_size, slice_size);
}

/**
 * N-d gather on GPU. The last dimension of `index` holds coordinates into the
 * leading index.dims()[-1] dimensions of `input`. All other dimensions of
 * `index` enumerate the gathered slices:
 *   output[r, ...] = input[index[r, 0], ..., index[r, end_size - 1], ...]
 *
 * output: written through output->data<T>(), so it must already be
 *   allocated by the caller.
 *
 * `DeviceContext` must provide stream(), which is used both for the
 * host-to-device copy of input's dims and for the kernel launch. A temporary
 * device buffer holding input's dims is allocated per call. Nothing is
 * launched when the output is empty.
 */
template <typename DeviceContext, typename T, typename IndexT = int>
void GPUGatherNd(const framework::ExecutionContext& context,
                 const Tensor& input, const Tensor& index, Tensor* output) {
  const auto& ctx = context.template device_context<DeviceContext>();
  const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
  auto cplace = platform::CPUPlace();

  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();
  auto input_dims = input.dims();
  auto input_dims_size = input_dims.size();

  const T* p_input = input.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // final dim: length of each index tuple
  int64_t end_size = index_dims[index_dims_size - 1];
  // remain dim: number of index tuples
  auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = framework::product(remain_ddim);
  // slice size: elements copied per tuple (dims of input not indexed)
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < input_dims_size; ++i) {
    slice_size *= input_dims[i];
  }

  const int64_t n = slice_size * remain_numel;
  // A launch with zero blocks is an invalid configuration; there is also
  // nothing to copy.
  if (n == 0) return;

  // source dim, copied to the device so the kernel can linearize tuples
  std::vector<int> v_input_dims(input_dims_size);
  for (int i = 0; i < input_dims_size; ++i) {
    v_input_dims[i] = static_cast<int>(input_dims[i]);
  }

  auto& dev_ctx = context.cuda_device_context();
  const size_t bytes = input_dims_size * sizeof(int);
  // NOTE(review): p_input_dims is released when this function returns, while
  // the copy and kernel queued on ctx.stream() may still be pending. This is
  // only safe if the allocator orders reuse on the same stream; confirm.
  auto p_input_dims = memory::Alloc(dev_ctx, bytes);
  int* g_input_dims = reinterpret_cast<int*>(p_input_dims->ptr());
  memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes,
               ctx.stream());

  const int block = 512;
  const int grid = GetGatherGridSize(n, block);

  GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_input, g_input_dims, p_index, p_output, remain_numel, slice_size,
      end_size);
}

}  // namespace operators
}  // namespace paddle