scatter.cu.h 9.6 KB
Newer Older
1
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Z
zchen0211 已提交
14 15

#pragma once
16
#include <unordered_set>
17
#include <vector>
18

19
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
W
Wang Xin 已提交
20
#include "paddle/phi/backends/gpu/gpu_primitives.h"
21 22
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
23
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zchen0211 已提交
24

25 26
namespace phi {
namespace funcs {
27

28
// Zeroes every row of `output` that is the target of at least one index.
//
// Flat position i in [0, index_size * slice_size) reads the row id
// indices[i / slice_size] and clears element (i % slice_size) of that row.
// `output` is viewed as output_count rows of slice_size contiguous elements.
//
//   indices      - device array of index_size row ids, each in
//                  [0, output_count) (checked with PADDLE_ENFORCE).
//   output       - device buffer that is modified in place.
//   output_count - number of rows in `output`.
//   index_size   - number of entries in `indices`.
//   slice_size   - number of elements per row.
//
// NOTE(review): CUDA_KERNEL_LOOP_TYPE is assumed to be Paddle's grid-stride
// loop macro, so any non-zero grid covers all elements — confirm.
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices,
                                      T* output,
                                      size_t output_count,
                                      size_t index_size,
                                      size_t slice_size) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT scatter_i = indices[indices_i];

    // The device-side message is formatted printf-style (assumed — confirm
    // against PADDLE_ENFORCE's device definition), which is not type-checked:
    // size_t / 64-bit IndexT must be passed as long long with %lld, otherwise
    // the reported values are garbage.
    PADDLE_ENFORCE(
        scatter_i >= 0 && scatter_i < output_count,
        "The index is out of bounds, "
        "please check whether the dimensions of index and "
        "input meet the requirements. It should "
        "be less than [%lld] and greater or equal to 0, but received [%lld]",
        static_cast<long long>(output_count),
        static_cast<long long>(scatter_i));

    int64_t out_i = scatter_i * slice_size + slice_i;
    *(output + out_i) = static_cast<T>(0);
  }
}
Z
zchen0211 已提交
52

53
// Writes (or accumulates) rows of `params` into the rows of `output`
// selected by `indices`.
//
// Flat position i in [0, index_size * slice_size) takes element
// (i % slice_size) of source row (i / slice_size). It lands in row
// indices[i / slice_size] of `output`, which is viewed as output_count rows
// of slice_size contiguous elements.
//
//   params       - source data: index_size rows of slice_size elements.
//   indices      - index_size target row ids, each in [0, output_count)
//                  (checked with PADDLE_ENFORCE).
//   output       - destination buffer, modified in place.
//   output_count - number of rows in `output`.
//   index_size   - number of entries in `indices`.
//   slice_size   - number of elements per row.
//   overwrite    - true:  plain store; if an id repeats, the racing stores
//                         leave an unspecified one of the values.
//                  false: phi::CudaAtomicAdd, so repeated ids accumulate on
//                         top of whatever `output` already holds.
//
// NOTE(review): CUDA_KERNEL_LOOP_TYPE is assumed to be Paddle's grid-stride
// loop macro — confirm.
template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params,
                                  const IndexT* indices,
                                  T* output,
                                  size_t output_count,
                                  size_t index_size,
                                  size_t slice_size,
                                  bool overwrite) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT scatter_i = indices[indices_i];

    // printf-style device formatting is not type-checked: pass 64-bit values
    // as long long with %lld (size_t / int64_t printed via %d is garbage).
    PADDLE_ENFORCE(
        scatter_i >= 0 && scatter_i < output_count,
        "The index is out of bounds, "
        "please check whether the dimensions of index and "
        "input meet the requirements. It should "
        "be less than [%lld] and greater or equal to 0, but received [%lld]",
        static_cast<long long>(output_count),
        static_cast<long long>(scatter_i));

    int64_t out_i = scatter_i * slice_size + slice_i;
    // `overwrite` is a kernel argument, so every thread takes the same branch.
    if (overwrite) {
      *(output + out_i) = *(params + i);
    } else {
      phi::CudaAtomicAdd(output + out_i, *(params + i));
    }
  }
}

84
// Atomically adds slices of `update` into `output` at offsets given by index
// tuples (the scatter_nd_add kernel).
//
// Flat position i in [0, remain_size * slice_size) belongs to tuple
// t = i / slice_size. Tuple t is indices[t * end_size, (t + 1) * end_size),
// and its entries address the leading end_size dims of `output`.
// The tuple is turned into a row-major element offset by scanning the
// entries from last to first, starting with a stride of slice_size.
// Element (i % slice_size) of update's slice t is then added at that offset
// plus (i % slice_size), so repeated tuples accumulate.
//
//   update      - remain_size * slice_size source elements.
//   indices     - remain_size tuples of end_size entries; entry j must lie in
//                 [0, output_dims[j]) (checked with PADDLE_ENFORCE).
//   output      - destination buffer, modified in place.
//   output_dims - output shape passed by value; only entries [0, end_size)
//                 are read here.
//   remain_size - number of index tuples.
//   slice_size  - elements per slice; must equal the product of
//                 output_dims[end_size:] for the offsets to be row-major.
//   end_size    - length of each index tuple.
//
// NOTE(review): CUDA_KERNEL_LOOP_TYPE is assumed to be Paddle's grid-stride
// loop macro — confirm.
template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update,
                                    const IndexT* indices,
                                    T* output,
                                    const Dim<DDim::kMaxRank> output_dims,
                                    size_t remain_size,
                                    size_t slice_size,
                                    size_t end_size) {
  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    int64_t gather_i = 0;
    // Stride of the dimension currently being addressed, innermost first.
    int64_t temp = slice_size;
    for (int64_t j = end_size - 1; j >= 0; --j) {
      IndexT index_value = indices[indices_i * end_size + j];

      // printf-style device formatting is not type-checked: output_dims[j]
      // is 64-bit and IndexT may be, so print both as long long with %lld.
      PADDLE_ENFORCE(
          index_value >= 0 && index_value < output_dims[j],
          "The index is out of bounds, "
          "please check whether the dimensions of index and "
          "input meet the requirements. It should "
          "be less than [%lld] and greater or equal to 0, but received [%lld]",
          static_cast<long long>(output_dims[j]),
          static_cast<long long>(index_value));

      gather_i += (index_value * temp);
      temp *= output_dims[j];
    }
    int64_t output_i = gather_i + slice_i;
    phi::CudaAtomicAdd(output + output_i, *(update + i));
  }
}

Z
zchen0211 已提交
117 118 119 120 121
/**
 * Scatters rows of `src` into `output` on the GPU, updating `output` in place.
 *
 * Row r of `src` (slice_size contiguous elements) is written to row index[r]
 * of `output`. `output` is viewed as output->dims()[0] rows of slice_size
 * elements.
 *
 * input[ctx]:       GPU context; both kernels are launched on ctx.stream().
 * input[src]:       type-T source tensor. With a 1-D/2-D index, slice_size is
 *                   the product of src dims[1:]. With a 0-D index, the whole
 *                   of `src` is a single row.
 * input[index]:     type-IndexT index tensor of shape [N], [N, 1] or [].
 * output[output]:   pre-allocated destination tensor, updated in place.
 * input[overwrite]: true  -> plain assignment; if an index repeats, which
 *                            source row wins is unspecified (racing stores).
 *                   false -> target rows are zeroed first, then every source
 *                            row is atomically added, so duplicates sum.
 *
 * Raises InvalidArgument if the index rank is not 0, 1 or 2, or if a 2-D
 * index's second dimension is not 1.
 */
template <typename T, typename IndexT = int>
void GPUScatterAssign(const phi::GPUContext& ctx,
                      const DenseTensor& src,
                      const DenseTensor& index,
                      DenseTensor* output,
                      bool overwrite = true) {
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(
        index.dims()[1],
        1,
        phi::errors::InvalidArgument("index.dims()[1] should be 1 when "
                                     "index.dims().size() = 2 in scatter_op. "
                                     "But received value is [%d]",
                                     index.dims()[1]));
  } else {
    PADDLE_ENFORCE_EQ(
        index.dims().size() == 1 || index.dims().size() == 0,
        true,
        phi::errors::InvalidArgument(
            "index.dims().size() should be 0, 1 or 2 in scatter_op. "
            "But received value is [%d]",
            index.dims().size()));
  }

  // A 0-D index addresses exactly one row.
  int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

  auto src_dims = src.dims();
  phi::DDim output_dims = output->dims();

  // Elements per scattered row. A 1-D/2-D index enumerates src's dim 0, so a
  // row is src dims[1:]. A 0-D index makes the whole of `src` one row.
  size_t slice_size = 1;
  for (int i = (index.dims().size() == 0 ? 0 : 1); i < src_dims.size(); ++i) {
    slice_size *= src_dims[i];
  }

  int64_t n = slice_size * index_size;
  // Nothing to scatter (empty index or zero-sized rows). Launching with a
  // zero-block grid would fail with an invalid-configuration error.
  if (n == 0) {
    return;
  }

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // One thread per element, grid clamped to the device limit. The kernels
  // loop over the remaining elements (see CUDA_KERNEL_LOOP_TYPE).
  int block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  phi::backends::gpu::LimitGridDim(ctx, &grid);

  // Accumulate mode: clear the target rows first. Both launches share
  // ctx.stream(), so the clear completes before any atomic add starts.
  if (!overwrite) {
    ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
        p_index, p_output, output_dims[0], index_size, slice_size);
  }

  ScatterCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_src,
                                                                 p_index,
                                                                 p_output,
                                                                 output_dims[0],
                                                                 index_size,
                                                                 slice_size,
                                                                 overwrite);
}

S
ShenLiang 已提交
189 190 191
// Zeroes the rows of `output` selected by `index`, in place.
//
// This function is only used for the scatter gradient with respect to X; the
// gradient with respect to the updates is computed with gather instead.
// NOTE(review): presumably `output` holds a copy of the upstream gradient
// and the selected rows get zero gradient because the forward pass replaced
// them — confirm against the caller.
//
//   ctx    - GPU context; the kernel is launched on ctx.stream().
//   index  - type-IndexT row ids of shape [N], [N, 1] or [] (0-D: one row).
//   output - gradient tensor whose rows (along dim 0) are cleared.
template <typename T, typename IndexT = int>
void GPUScatterGradForX(const phi::GPUContext& ctx,
                        const DenseTensor& index,
                        DenseTensor* output) {
  // A 0-D index addresses exactly one row.
  int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
  auto dst_dims = output->dims();

  // Index entries address dim 0 of `output` (they are bounds-checked against
  // dst_dims[0]), so a row is always the product of dims[1:], whatever the
  // index rank. Using all dims for a 0-D index used to make each "row" the
  // whole tensor. That zeroed everything for index 0 and wrote out of bounds
  // for any index > 0.
  int64_t slice_size = 1;
  for (int i = 1; i < dst_dims.size(); ++i) {
    slice_size *= dst_dims[i];
  }

  int64_t n = slice_size * index_size;
  // Nothing to clear. A zero-block grid would be an invalid launch config.
  if (n == 0) {
    return;
  }

  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // One thread per element, grid clamped to the device limit.
  int64_t block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  phi::backends::gpu::LimitGridDim(ctx, &grid);

  ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_index, p_output, dst_dims[0], index_size, slice_size);
}

219 220 221 222 223
// Adds slices of `update` into `output` at positions given by index tuples
// (scatter_nd_add), updating `output` in place on ctx.stream().
//
//   update - remain_numel * slice_size source elements. Assumed laid out as
//            index.shape[:-1] + output.shape[end_size:] — confirm with caller.
//   index  - type-IndexT tensor. Its last dim (end_size) is the tuple length;
//            each tuple addresses the leading end_size dims of `output`, and
//            the leading index dims enumerate remain_numel tuples.
//   output - destination. The kernel uses atomic adds, so repeated tuples
//            accumulate.
template <typename T, typename IndexT = int>
void GPUScatterNdAdd(const phi::GPUContext& ctx,
                     const DenseTensor& update,
                     const DenseTensor& index,
                     DenseTensor* output) {
  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();

  auto output_dims = output->dims();
  auto output_dims_size = output_dims.size();

  // Length of each index tuple (final index dim).
  int64_t end_size = index_dims[index_dims_size - 1];
  // Number of index tuples (product of the remaining index dims).
  auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = phi::product(remain_ddim);
  // Elements per slice: the output dims not addressed by a tuple.
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < output_dims_size; ++i) {
    slice_size *= output_dims[i];
  }

  int64_t n = slice_size * remain_numel;
  // Nothing to add (no tuples or empty slices). A zero-block grid would be an
  // invalid launch configuration.
  if (n == 0) {
    return;
  }

  const T* p_update = update.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // Copy the shape into a fixed-size Dim so it can be passed to the kernel by
  // value. Entries past output_dims_size stay unset, and the kernel reads only
  // the first end_size entries.
  Dim<DDim::kMaxRank> g_output_dims;
  for (int i = 0; i < output_dims_size; ++i) {
    g_output_dims[i] = output_dims[i];
  }

  // One thread per element, grid clamped to the device limit.
  int block = 512;
  dim3 grid = dim3((n + block - 1) / block);
  phi::backends::gpu::LimitGridDim(ctx, &grid);

  ScatterNdCUDAKernel<T, IndexT>
      <<<grid, block, 0, ctx.stream()>>>(p_update,
                                         p_index,
                                         p_output,
                                         g_output_dims,
                                         remain_numel,
                                         slice_size,
                                         end_size);
}

266
}  // namespace funcs
267
}  // namespace phi