scatter.cu.h 4.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zchen0211 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Z
zchen0211 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Z
zchen0211 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Z
zchen0211 已提交
14 15

#pragma once
16 17
#include <unordered_set>
#include "math/math_function.h"
Y
Yi Wang 已提交
18
#include "paddle/fluid/framework/tensor.h"
19
#include "paddle/fluid/platform/cuda_primitives.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/platform/place.h"
Z
zchen0211 已提交
21 22 23 24

namespace paddle {
namespace operators {

25 26
using Tensor = framework::Tensor;

Z
zchen0211 已提交
27 28 29
// Grid-stride loop header for kernels launched with a 1-D grid of 1-D blocks.
// Declares `int i`, starts it at the calling thread's global index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances it by the total number
// of launched threads (blockDim.x * gridDim.x) until it reaches `n`, so every
// value in [0, n) is visited by exactly one thread regardless of grid size.
// Use it as:  CUDA_1D_KERNEL_LOOP(i, n) { ...body using i... }
// NOTE(review): the loop variable is a 32-bit `int`, so `n` is expected to
// stay within the int range - confirm for callers with very large tensors.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)
30 31 32 33 34 35 36 37 38 39 40 41
/**
 * Zero-fills every output slice that is addressed by `indices`.
 *
 * One loop iteration handles one element: flat position i is split into an
 * index position (i / slice_size) and an offset inside the slice, and
 * output[indices[i / slice_size] * slice_size + offset] is set to T(0).
 *
 * indices:     index_size entries selecting the destination slice
 * output:      destination buffer, addressed in units of slice_size elements
 * index_size:  number of entries in indices
 * slice_size:  number of elements in one slice
 * overwrite:   not read by this kernel; kept so the signature stays unchanged
 *
 * Launch: 1-D grid of 1-D blocks (only the .x components are used); any grid
 * size works because CUDA_1D_KERNEL_LOOP strides over the whole range.
 * Index values are not range-checked here.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
                                      size_t index_size, size_t slice_size,
                                      bool overwrite) {
  CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
    const size_t indices_i = i / slice_size;
    const size_t slice_i = i - indices_i * slice_size;  // offset inside slice
    const IndexT scatter_i = indices[indices_i];
    // Keep the destination offset in size_t: storing it in IndexT (int by
    // default) truncates it once the output holds more than INT_MAX elements.
    const size_t out_i = static_cast<size_t>(scatter_i) * slice_size + slice_i;
    output[out_i] = static_cast<T>(0);
  }
}
Z
zchen0211 已提交
42

43 44
/**
 * Scatters slices of `params` into `output` according to `indices`.
 *
 * One loop iteration handles one element: flat position i of params is split
 * into an index position (i / slice_size) and an offset inside the slice, and
 * params[i] goes to output[indices[i / slice_size] * slice_size + offset].
 *
 * params:      source values, read at flat positions
 *              [0, index_size * slice_size)
 * indices:     index_size entries selecting the destination slice
 * output:      destination buffer, addressed in units of slice_size elements
 * index_size:  number of entries in indices
 * slice_size:  number of elements in one slice
 * overwrite:   true  - the source element is stored with a plain assignment
 *              false - the source element is added with CudaAtomicAdd, so
 *                      entries that share an index accumulate
 *
 * Launch: 1-D grid of 1-D blocks (only the .x components are used); any grid
 * size works because CUDA_1D_KERNEL_LOOP strides over the whole range.
 * Index values are not range-checked here.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
                                  T* output, size_t index_size,
                                  size_t slice_size, bool overwrite) {
  CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
    const size_t indices_i = i / slice_size;
    const size_t slice_i = i - indices_i * slice_size;  // offset inside slice
    const IndexT scatter_i = indices[indices_i];
    // Keep the destination offset in size_t: storing it in IndexT (int by
    // default) truncates it once the output holds more than INT_MAX elements.
    const size_t out_i = static_cast<size_t>(scatter_i) * slice_size + slice_i;
    if (overwrite) {
      output[out_i] = params[i];
    } else {
      paddle::platform::CudaAtomicAdd(output + out_i, params[i]);
    }
  }
}

/**
 * A thin wrapper that scatters a GPU tensor into another GPU tensor.
 *
 * Slice k of `src` (all dimensions after the first, flattened) is written to
 * slice index[k] of `output`. Nothing is returned: `output` is updated in
 * place through the pointer obtained from output->data<T>().
 *
 * context:   provides the device context whose CUDA stream runs the kernels
 * src:       type-T source Tensor
 * index:     type-IndexT index Tensor, 1-D or 2-D with a second dim of 1
 * output:    type-T destination Tensor, modified in place
 * overwrite: true  - addressed output slices are assigned from src
 *            false - addressed output slices are zeroed first, then the src
 *                    slices are accumulated with atomic adds
 *
 * The kernels are only enqueued on the stream; this function does not
 * synchronize with the device.
 * NOTE(review): sizes are held in `int` and the kernels iterate with an `int`
 * loop variable, so index_size * slice_size is expected to fit in int.
 */
template <typename T, typename IndexT = int>
void GPUScatterAssign(const framework::ExecutionContext& context,
                      const Tensor& src, const Tensor& index, Tensor* output,
                      bool overwrite = true) {
  // PADDLE_ENFORCE(platform::is_gpu_place(place));
  const auto& ctx = context.device_context();

  // check index of shape 1-D (or 2-D with a trailing dimension of 1)
  PADDLE_ENFORCE(index.dims().size() == 1 ||
                 (index.dims().size() == 2 && index.dims()[1] == 1));
  int index_size = index.dims()[0];

  // slice size: product of every src dimension except the first
  auto src_dims = src.dims();
  int slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // Nothing to scatter. Returning here also avoids launching with a grid
  // size of 0, which is an invalid launch configuration.
  int n = slice_size * index_size;
  if (n == 0) return;

  // set block and grid num
  int block = 512;
  int grid = (n + block - 1) / block;

  // Both kernels go to the same stream, so the zero-fill is ordered before
  // the accumulation.
  auto stream =
      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();

  // if not overwrite mode, init data
  if (!overwrite) {
    ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, stream>>>(
        p_index, p_output, index_size, slice_size, overwrite);
  }

  ScatterCUDAKernel<T, IndexT><<<grid, block, 0, stream>>>(
      p_src, p_index, p_output, index_size, slice_size, overwrite);
}

}  // namespace operators
}  // namespace paddle