scatter.cu.h 8.9 KB
Newer Older
1
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Z
zchen0211 已提交
14 15

#pragma once
16
#include <unordered_set>
17
#include <vector>
18

19
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
20
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
21 22
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
23
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zchen0211 已提交
24

25 26
namespace phi {
namespace funcs {
27

28
/**
 * Writes zero into every output slice that is addressed by `indices`.
 *
 * The flat element index i covers [0, index_size * slice_size). Element i
 * belongs to index entry i / slice_size and lies at offset i % slice_size
 * inside the slice selected by that entry.
 *
 * indices:    index_size slice ids; each selects the slice of `output`
 *             that starts at output + id * slice_size
 * output:     destination buffer
 * index_size: number of entries read from `indices`
 * slice_size: number of elements per slice
 *
 * Only the lower bound (id >= 0) is checked here; the kernel is not told how
 * many slices `output` holds, so the upper bound must be guaranteed by the
 * caller.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices,
                                      T* output,
                                      size_t index_size,
                                      size_t slice_size) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT scatter_i = indices[indices_i];

    // IndexT may be wider than int, so the value is widened explicitly and
    // printed with %lld to keep the format specifier and the argument in sync.
    PADDLE_ENFORCE(scatter_i >= 0,
                   "The index is out of bounds, "
                   "please check whether the dimensions of index and "
                   "input meet the requirements. It should "
                   "be greater than or equal to 0, but received [%lld]",
                   static_cast<long long>(scatter_i));  // NOLINT

    int64_t out_i = scatter_i * slice_size + slice_i;
    *(output + out_i) = static_cast<T>(0);
  }
}
Z
zchen0211 已提交
49

50
/**
 * Scatters slices of `params` into `output` according to `indices`.
 *
 * The flat element index i covers [0, index_size * slice_size). Element i of
 * `params` belongs to index entry i / slice_size and is written to the same
 * offset inside the output slice selected by that entry.
 *
 * params:     source buffer holding index_size * slice_size elements
 * indices:    index_size slice ids into `output`
 * output:     destination buffer
 * index_size: number of entries read from `indices`
 * slice_size: number of elements per slice
 * overwrite:  true  -> plain store (when the same id occurs more than once,
 *                      which source element wins is not determined here);
 *             false -> CudaAtomicAdd into the destination element
 *
 * Only the lower bound (id >= 0) is checked here; the kernel is not told how
 * many slices `output` holds, so the upper bound must be guaranteed by the
 * caller.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params,
                                  const IndexT* indices,
                                  T* output,
                                  size_t index_size,
                                  size_t slice_size,
                                  bool overwrite) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT scatter_i = indices[indices_i];

    // IndexT may be wider than int, so the value is widened explicitly and
    // printed with %lld to keep the format specifier and the argument in sync.
    PADDLE_ENFORCE(scatter_i >= 0,
                   "The index is out of bounds, "
                   "please check whether the dimensions of index and "
                   "input meet the requirements. It should "
                   "be greater than or equal to 0, but received [%lld]",
                   static_cast<long long>(scatter_i));  // NOLINT

    int64_t out_i = scatter_i * slice_size + slice_i;
    if (overwrite) {
      *(output + out_i) = *(params + i);
    } else {
      paddle::platform::CudaAtomicAdd(output + out_i, *(params + i));
    }
  }
}

78
/**
 * Accumulates `update` into `output` at positions given by multi-dimensional
 * indices (scatter_nd_add).
 *
 * The flat element index i covers [0, remain_size * slice_size). Element i of
 * `update` belongs to index tuple i / slice_size and lies at offset
 * i % slice_size inside the addressed slice. Each tuple consists of end_size
 * consecutive values in `indices`; they are folded, last component first,
 * into a flat offset using the leading end_size extents of output_dims.
 *
 * update:      source buffer holding remain_size * slice_size elements
 * indices:     remain_size tuples of end_size index values each
 * output:      destination buffer, updated with CudaAtomicAdd
 * output_dims: extents of the output tensor
 * remain_size: number of index tuples
 * slice_size:  number of elements addressed by one tuple
 * end_size:    number of components per index tuple
 *
 * Every component is checked against [0, output_dims[j]).
 */
template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update,
                                    const IndexT* indices,
                                    T* output,
                                    const Dim<DDim::kMaxRank> output_dims,
                                    size_t remain_size,
                                    size_t slice_size,
                                    size_t end_size) {
  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    int64_t gather_i = 0;
    // Stride of the component currently being folded; starts at the slice
    // size and grows by each visited extent.
    int64_t temp = slice_size;
    for (int64_t j = end_size - 1; j >= 0; --j) {
      IndexT index_value = indices[indices_i * end_size + j];

      // Both printed values may be 64 bits wide; widen them explicitly and
      // use %lld so that specifiers and arguments agree.
      PADDLE_ENFORCE(
          index_value >= 0 && index_value < output_dims[j],
          "The index is out of bounds, "
          "please check whether the dimensions of index and "
          "input meet the requirements. It should "
          "be less than [%lld] and greater or equal to 0, but received [%lld]",
          static_cast<long long>(output_dims[j]),  // NOLINT
          static_cast<long long>(index_value));    // NOLINT

      gather_i += (index_value * temp);
      temp *= output_dims[j];
    }
    int64_t output_i = gather_i + slice_i;
    paddle::platform::CudaAtomicAdd(output + output_i, *(update + i));
  }
}

Z
zchen0211 已提交
111 112 113 114 115
/**
 * A thin wrapper on gpu tensor.
 * Scatters the slices of `src` along dimension 0 into `output`, choosing the
 * destination slice of each source slice from `index`.
 *
 * ctx:       GPU context; supplies the stream and the grid-size limit
 * src:       type-T source tensor; dimensions 1.. define the slice size
 * index:     type-IndexT index tensor, 1-D or 2-D with a second extent of 1
 * output:    type-T destination tensor, updated in place
 * overwrite: true  -> addressed slices are overwritten with the source slice;
 *            false -> addressed slices are first zeroed, then the source
 *                     slices are accumulated into them
 *
 * Raises InvalidArgument when `index` has an unsupported shape.
 * Nothing is launched when there is no element to scatter.
 * NOTE(review): src.dims()[0] is not compared with index.dims()[0] here, and
 * the kernels only check the lower bound of each index - presumably callers
 * validate both.
 */
template <typename T, typename IndexT = int>
void GPUScatterAssign(const phi::GPUContext& ctx,
                      const DenseTensor& src,
                      const DenseTensor& index,
                      DenseTensor* output,
                      bool overwrite = true) {
  // check index of shape 1-D
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(
        index.dims()[1],
        1,
        phi::errors::InvalidArgument("index.dims()[1] should be 1 when "
                                     "index.dims().size() = 2 in scatter_op."
                                     "But received value is [%d]",
                                     index.dims()[1]));
  } else {
    PADDLE_ENFORCE_EQ(index.dims().size(),
                      1,
                      phi::errors::InvalidArgument(
                          "index.dims().size() should be 1 or 2 in scatter_op."
                          "But received value is [%d]",
                          index.dims().size()));
  }
  int64_t index_size = index.dims()[0];

  // slice size: number of elements in one slice of src along dimension 0
  auto src_dims = src.dims();
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  // Total number of elements to move. With nothing to move the grid size
  // below would be 0, which is not a valid launch configuration, so stop here.
  int64_t n = slice_size * index_size;
  if (n == 0) return;

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // set block and grid num
  int block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  paddle::platform::LimitGridDim(ctx, &grid);

  // if not overwrite mode, init data
  if (!overwrite) {
    ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
        p_index, p_output, index_size, slice_size);
  }

  ScatterCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_src, p_index, p_output, index_size, slice_size, overwrite);
}

S
ShenLiang 已提交
173 174 175
/**
 * Used only for the gradient of scatter with respect to x; the gradient with
 * respect to the updates is computed with gather instead.
 *
 * Zeroes every slice of `output` (along dimension 0) that is addressed by
 * `index`; all other elements of `output` are left untouched.
 *
 * ctx:    GPU context; supplies the stream and the grid-size limit
 * index:  type-IndexT index tensor; index.dims()[0] entries are read
 * output: type-T tensor updated in place; dimensions 1.. define the slice size
 *
 * Nothing is launched when there is no element to zero.
 */
template <typename T, typename IndexT = int>
void GPUScatterGradForX(const phi::GPUContext& ctx,
                        const DenseTensor& index,
                        DenseTensor* output) {
  int64_t index_size = index.dims()[0];

  // slice size: number of elements in one slice of output along dimension 0
  auto dst_dims = output->dims();
  int64_t slice_size = 1;
  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];

  // Total number of elements to zero. With nothing to do the grid size below
  // would be 0, which is not a valid launch configuration, so stop here.
  int64_t n = slice_size * index_size;
  if (n == 0) return;

  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // set block and grid num
  int block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  paddle::platform::LimitGridDim(ctx, &grid);

  ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_index, p_output, index_size, slice_size);
}

199 200 201 202 203
/**
 * Adds `update` into `output` at the positions named by multi-dimensional
 * indices (scatter_nd_add) on the GPU.
 *
 * ctx:    GPU context; supplies the stream and the grid-size limit
 * update: type-T tensor with the values to accumulate
 * index:  type-IndexT tensor; its last dimension holds one index tuple, all
 *         leading dimensions enumerate the tuples
 * output: type-T tensor updated in place with atomic adds
 *
 * A tuple of length end_size addresses the leading end_size dimensions of
 * `output`; the remaining dimensions form the slice that one tuple updates.
 * Nothing is launched when there is no element to add.
 * NOTE(review): the shape of `update` is not checked here - presumably the
 * caller guarantees it holds (number of tuples) * (slice size) elements.
 */
template <typename T, typename IndexT = int>
void GPUScatterNdAdd(const phi::GPUContext& ctx,
                     const DenseTensor& update,
                     const DenseTensor& index,
                     DenseTensor* output) {
  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();

  auto output_dims = output->dims();
  auto output_dims_size = output_dims.size();

  // final dim: length of one index tuple
  int64_t end_size = index_dims[index_dims_size - 1];
  // remain dim: number of index tuples
  auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = phi::product(remain_ddim);
  // slice size: product of the output extents not addressed by a tuple
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < output_dims_size; ++i) {
    slice_size *= output_dims[i];
  }

  // Total number of elements to add. With nothing to do the grid size below
  // would be 0, which is not a valid launch configuration, so stop here.
  int64_t n = slice_size * remain_numel;
  if (n == 0) return;

  const T* p_update = update.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // Copy the output extents into a fixed-size Dim that is passed to the
  // kernel by value.
  Dim<DDim::kMaxRank> g_output_dims;
  for (int i = 0; i < output_dims_size; ++i) {
    g_output_dims[i] = output_dims[i];
  }

  // set block and grid num
  int block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  paddle::platform::LimitGridDim(ctx, &grid);

  ScatterNdCUDAKernel<T, IndexT>
      <<<grid, block, 0, ctx.stream()>>>(p_update,
                                         p_index,
                                         p_output,
                                         g_output_dims,
                                         remain_numel,
                                         slice_size,
                                         end_size);
}

246
}  // namespace funcs
247
}  // namespace phi