scatter.cu.h 9.0 KB
Newer Older
1
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Z
zchen0211 已提交
14 15

#pragma once
16
#include <unordered_set>
17
#include <vector>
18
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
19 20
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
21
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zchen0211 已提交
22

23 24
namespace phi {
namespace funcs {
25

26
/**
 * Zeroes the rows of `output` that are addressed by `indices`.
 *
 * `output` is treated as a row-major matrix with `slice_size` columns. Logical
 * element `i` in [0, index_size * slice_size) maps to row
 * `indices[i / slice_size]`, column `i % slice_size`. The element loop is
 * driven by CUDA_KERNEL_LOOP_TYPE; launchers in this file cap the grid size,
 * which relies on that macro striding over the whole range
 * (NOTE(review): confirm against the macro definition).
 *
 * Only the lower bound of each index is checked: the kernel is not given the
 * number of rows in `output`, so callers must keep indices in range.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices,
                                      T* output,
                                      size_t index_size,
                                      size_t slice_size) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT scatter_i = indices[indices_i];

    // The index is widened to long long and printed with "%lld" so the
    // specifier matches for both int and int64_t IndexT; a bare "%d" has the
    // wrong width for 64-bit indices (assuming PADDLE_ENFORCE forwards its
    // format and arguments to device printf).
    PADDLE_ENFORCE(scatter_i >= 0,
                   "The index is out of bounds, "
                   "please check whether the dimensions of index and "
                   "input meet the requirements. It should "
                   "be greater than or equal to 0, but received [%lld]",
                   static_cast<long long>(scatter_i));

    int64_t out_i = scatter_i * slice_size + slice_i;
    output[out_i] = static_cast<T>(0);
  }
}
Z
zchen0211 已提交
47

48
/**
 * Scatters `params` (index_size rows of slice_size elements each) into
 * `output`: row k of `params` is written to row `indices[k]` of `output`.
 *
 * @param overwrite If true, values are assigned with plain stores. When
 *                  indices contain duplicates, which write survives is
 *                  unspecified. If false, values are accumulated with
 *                  CudaAtomicAdd, so the destination rows must already hold
 *                  the desired starting value (GPUScatterAssign zeroes them
 *                  first with ScatterInitCUDAKernel).
 *
 * Only the lower bound of each index is checked; the kernel does not know how
 * many rows `output` has.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params,
                                  const IndexT* indices,
                                  T* output,
                                  size_t index_size,
                                  size_t slice_size,
                                  bool overwrite) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT scatter_i = indices[indices_i];

    // Widen to long long so "%lld" matches for both int and int64_t IndexT
    // (assuming PADDLE_ENFORCE forwards to device printf).
    PADDLE_ENFORCE(scatter_i >= 0,
                   "The index is out of bounds, "
                   "please check whether the dimensions of index and "
                   "input meet the requirements. It should "
                   "be greater than or equal to 0, but received [%lld]",
                   static_cast<long long>(scatter_i));

    int64_t out_i = scatter_i * slice_size + slice_i;
    if (overwrite) {
      output[out_i] = params[i];
    } else {
      paddle::platform::CudaAtomicAdd(output + out_i, params[i]);
    }
  }
}

76
/**
 * scatter_nd_add kernel.
 *
 * `indices` holds remain_size rows of end_size coordinates each. For every
 * row, the coordinates are bounds-checked against `output_dims[j]`. They are
 * then linearized in row-major order, with the innermost coordinate scaled by
 * slice_size. Finally, slice_size consecutive elements of `update` are added
 * at that offset with CudaAtomicAdd, so duplicate rows accumulate.
 *
 * `output_dims` is dereferenced inside the kernel, so it must point to device
 * memory holding at least end_size dimension extents.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update,
                                    const IndexT* indices,
                                    T* output,
                                    const int64_t* output_dims,
                                    size_t remain_size,
                                    size_t slice_size,
                                    size_t end_size) {
  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    int64_t gather_i = 0;
    int64_t temp = slice_size;
    // Signed start so end_size == 0 simply skips the loop instead of relying
    // on a size_t wraparound converted back to int64_t.
    for (int64_t j = static_cast<int64_t>(end_size) - 1; j >= 0; --j) {
      IndexT index_value = indices[indices_i * end_size + j];

      // Both values may be 64-bit (output_dims is int64_t, IndexT may be),
      // so they are widened to long long and printed with "%lld" (assuming
      // PADDLE_ENFORCE forwards to device printf).
      PADDLE_ENFORCE(
          index_value >= 0 && index_value < output_dims[j],
          "The index is out of bounds, "
          "please check whether the dimensions of index and "
          "input meet the requirements. It should "
          "be less than [%lld] and greater than or equal to 0, but received "
          "[%lld]",
          static_cast<long long>(output_dims[j]),
          static_cast<long long>(index_value));

      gather_i += (index_value * temp);
      temp *= output_dims[j];
    }
    int64_t output_i = gather_i + slice_i;
    paddle::platform::CudaAtomicAdd(output + output_i, update[i]);
  }
}

Z
zchen0211 已提交
109 110 111 112 113
/**
 * A thin wrapper on GPU tensors: scatters the rows of `src` into `output` at
 * the row positions given by `index`.
 *
 * @param ctx       GPU context; every kernel is launched on ctx.stream().
 * @param src       type-T source tensor. Row k (all dims after the first,
 *                  flattened into slice_size elements) goes to output row
 *                  index[k].
 * @param index     type-IndexT index tensor of shape [N] or [N, 1].
 * @param output    Destination tensor, updated in place. Its buffer is read
 *                  via data<T>() and is not allocated here.
 * @param overwrite If true, rows are assigned. If false, the targeted rows
 *                  are zeroed first and the source rows are then accumulated
 *                  with atomic adds.
 */
template <typename T, typename IndexT = int>
void GPUScatterAssign(const phi::GPUContext& ctx,
                      const DenseTensor& src,
                      const DenseTensor& index,
                      DenseTensor* output,
                      bool overwrite = true) {
  // index must be 1-D, or 2-D with a trailing dimension of 1.
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(
        index.dims()[1],
        1,
        phi::errors::InvalidArgument("index.dims()[1] should be 1 when "
                                     "index.dims().size() = 2 in scatter_op. "
                                     "But received value is [%d]",
                                     index.dims()[1]));
  } else {
    PADDLE_ENFORCE_EQ(index.dims().size(),
                      1,
                      phi::errors::InvalidArgument(
                          "index.dims().size() should be 1 or 2 in scatter_op. "
                          "But received value is [%d]",
                          index.dims().size()));
  }
  const int64_t index_size = index.dims()[0];

  // slice size: number of elements in one row of src.
  const auto src_dims = src.dims();
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  const int64_t n = slice_size * index_size;
  // Nothing to scatter. Returning also avoids launching with a zero-sized
  // grid, which CUDA rejects as an invalid configuration.
  if (n == 0) return;

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // set block and grid num. The grid is capped at the device's max x-dimension
  // so it cannot be truncated when converted to dim3; the kernels' loop macro
  // is relied on to cover any remaining elements.
  const int block = 512;
  const int64_t height = (n + block - 1) / block;
  const int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
  const int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;

  // if not overwrite mode, init data: the targeted rows start from zero
  // before being accumulated into.
  if (!overwrite) {
    ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
        p_index, p_output, index_size, slice_size);
  }

  ScatterCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_src, p_index, p_output, index_size, slice_size, overwrite);
}

S
ShenLiang 已提交
170 171 172
// Zeroes the rows of `output` addressed by `index` (via ScatterInitCUDAKernel).
// The function is only used for the gradient of scatter with respect to X;
// the gradient with respect to the updates uses gather instead.
//
// NOTE(review): `output` presumably already holds the upstream gradient of
// Out, so only the scattered-over rows are cleared - confirm with the caller.
template <typename T, typename IndexT = int>
void GPUScatterGradForX(const phi::GPUContext& ctx,
                        const DenseTensor& index,
                        DenseTensor* output) {
  const int64_t index_size = index.dims()[0];
  const auto dst_dims = output->dims();
  // slice size: number of elements in one row of output.
  int64_t slice_size = 1;
  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];

  const int64_t n = slice_size * index_size;
  // Nothing to clear; also avoids a zero-sized grid launch.
  if (n == 0) return;

  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // set block and grid num, capping the grid at the device's max x-dimension.
  const int64_t block = 512;
  const int64_t height = (n + block - 1) / block;
  const int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
  const int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;

  ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_index, p_output, index_size, slice_size);
}

197 198 199 200 201
/**
 * scatter_nd_add on GPU: adds slices of `update` into `output` at the
 * positions described by the last dimension of `index`.
 *
 * With `index` of shape [..., end_size], each index row supplies the first
 * end_size coordinates of `output`. The output dims from end_size onward form
 * a slice of slice_size elements, and that slice is accumulated atomically by
 * ScatterNdCUDAKernel, so duplicate index rows sum their updates.
 *
 * @param ctx    GPU context; the dims copy and the kernel use ctx.stream().
 * @param update Values to add. The kernel reads remain_numel * slice_size
 *               elements from it (presumably shaped index.shape[:-1] +
 *               output.shape[end_size:] - TODO confirm against shape
 *               inference).
 * @param index  type-IndexT coordinates, last dimension = end_size.
 * @param output Destination tensor, updated in place (buffer read via
 *               data<T>(), not allocated here).
 */
template <typename T, typename IndexT = int>
void GPUScatterNdAdd(const phi::GPUContext& ctx,
                     const DenseTensor& update,
                     const DenseTensor& index,
                     DenseTensor* output) {
  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();

  auto output_dims = output->dims();
  auto output_dims_size = output_dims.size();

  // final dim: number of coordinates per index row
  int64_t end_size = index_dims[index_dims_size - 1];
  // remain dim: number of index rows
  auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = phi::product(remain_ddim);
  // slice size: elements addressed by one index row
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < output_dims_size; ++i) {
    slice_size *= output_dims[i];
  }

  const int64_t n = slice_size * remain_numel;
  // Nothing to add. Returning also skips the dims copy and avoids launching
  // with a zero-sized grid (an invalid CUDA configuration).
  if (n == 0) return;

  const T* p_update = update.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // Put output_dims into device memory so the kernel can bounds-check and
  // linearize each index row (gplace = device, cplace = host).
  const auto gplace = ctx.GetPlace();
  auto cplace = phi::CPUPlace();

  std::vector<int64_t> v_output_dims(output_dims_size);
  for (int i = 0; i < output_dims_size; ++i) {
    v_output_dims[i] = output_dims[i];
  }

  phi::DenseTensor out_dims_tensor;
  out_dims_tensor.Resize({output_dims_size});
  auto* g_output_dims = ctx.Alloc<int64_t>(&out_dims_tensor);
  int64_t bytes = output_dims_size * sizeof(int64_t);
  // NOTE(review): out_dims_tensor is released when this function returns,
  // possibly before the asynchronously launched kernel has read it. Confirm
  // that the context's allocator does not hand this memory to another stream
  // before ctx.stream() has caught up.
  paddle::memory::Copy(
      gplace, g_output_dims, cplace, v_output_dims.data(), bytes, ctx.stream());

  // set block and grid num, capping the grid at the device's max x-dimension.
  const int block = 512;
  const int64_t height = (n + block - 1) / block;
  const int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
  const int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;

  ScatterNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_update,
      p_index,
      p_output,
      g_output_dims,
      remain_numel,
      slice_size,
      end_size);
}

254 255
}  // namespace funcs
}  // namespace phi