/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <vector>

#include "math/math_function.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Grid-stride loop over [0, n) with a 32-bit `int` index.
// The body is deliberately left unchanged: if another header defines this
// same macro with a different body, including both in one translation unit
// would produce a macro-redefinition diagnostic. The kernels in this file use
// explicit 64-bit loops instead, so they index tensors with more than
// 2^31 - 1 elements correctly.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Returns how many blocks of `block_size` threads are needed to give each of
// `numel` elements its own thread, clamped to 2^31 - 1 (the gridDim.x limit on
// compute capability 3.0+). Elements beyond grid * block_size are covered by
// the kernels' grid-stride loops. Callers pass numel > 0.
inline int ScatterGridSize(int64_t numel, int block_size) {
  const int64_t grid = (numel + block_size - 1) / block_size;
  return static_cast<int>(
      std::min<int64_t>(grid, std::numeric_limits<int>::max()));
}

/**
 * Zeroes every row of `output` that is the target of at least one index.
 *
 * Launch: any 1-D grid; work is distributed with a grid-stride loop.
 * indices:    index_size row ids into `output`.
 * output:     buffer whose row r occupies elements
 *             [r * slice_size, (r + 1) * slice_size).
 * overwrite:  unused; kept so the signature mirrors ScatterCUDAKernel.
 * A repeated index zeroes the same row more than once, which is harmless.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
                                      size_t index_size, size_t slice_size,
                                      bool overwrite) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t total = static_cast<int64_t>(index_size * slice_size);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    const int64_t out_i =
        static_cast<int64_t>(indices[indices_i]) * slice + slice_i;
    output[out_i] = static_cast<T>(0);
  }
}

/**
 * Scatters `params` into `output` row by row: row k of params (slice_size
 * contiguous elements) goes to row indices[k] of output.
 *
 * Launch: any 1-D grid; work is distributed with a grid-stride loop.
 * overwrite == true:  plain stores. If an index repeats, the threads race and
 *                     which source row survives is unspecified.
 * overwrite == false: every contribution is accumulated with CudaAtomicAdd,
 *                     so repeated indices sum.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
                                  T* output, size_t index_size,
                                  size_t slice_size, bool overwrite) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t total = static_cast<int64_t>(index_size * slice_size);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    const int64_t out_i =
        static_cast<int64_t>(indices[indices_i]) * slice + slice_i;
    if (overwrite) {
      output[out_i] = params[i];
    } else {
      paddle::platform::CudaAtomicAdd(output + out_i, params[i]);
    }
  }
}

/**
 * N-d scatter-add.
 *
 * `indices` is read as remain_size tuples of end_size coordinates. Tuple k is
 * turned into a row-major offset over the first end_size dimensions of
 * output, whose sizes are read from the device array `output_dims`. That
 * offset is scaled by slice_size. Row k of `update` (slice_size contiguous
 * elements) is then atomically added there.
 *
 * Launch: any 1-D grid; work is distributed with a grid-stride loop.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
                                    T* output, const int* output_dims,
                                    size_t remain_size, size_t slice_size,
                                    size_t end_size) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t depth = static_cast<int64_t>(end_size);
  const int64_t total = static_cast<int64_t>(remain_size * slice_size);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    // Accumulate in 64 bits so large outputs do not overflow IndexT.
    int64_t gather_i = 0;
    int64_t temp = slice;
    for (int64_t j = depth - 1; j >= 0; --j) {
      const IndexT index_value = indices[indices_i * depth + j];
      gather_i += static_cast<int64_t>(index_value) * temp;
      temp *= output_dims[j];
    }
    paddle::platform::CudaAtomicAdd(output + gather_i + slice_i, update[i]);
  }
}

/**
 * Scatters the rows of `src` into `output` on the GPU. The write happens in
 * place; nothing is returned.
 *
 * src:       type-T tensor; each row is its dims[1:] flattened (slice_size
 *            elements). NOTE(review): src is assumed to have at least
 *            index.dims()[0] rows; this is not checked here.
 * index:     type-IndexT tensor of shape [N] or [N, 1]; row k of src is
 *            written to row index[k] of output.
 * output:    destination; obtained via output->data<T>() (not mutable_data),
 *            so presumably already allocated by the caller. Rows not named by
 *            `index` are left untouched.
 * overwrite: true  -> targeted rows are overwritten.
 *            false -> targeted rows are zeroed first, then all source rows
 *                     aimed at them are accumulated with atomic adds.
 *
 * Both kernels are enqueued on the context's CUDA stream. The function does
 * not synchronize. An empty scatter (N == 0 or zero-sized rows) is a no-op.
 */
template <typename T, typename IndexT = int>
void GPUScatterAssign(const framework::ExecutionContext& context,
                      const Tensor& src, const Tensor& index, Tensor* output,
                      bool overwrite = true) {
  // index must be 1-D, or 2-D with a trailing dimension of 1.
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
                      "index.dims()[1] should be 1 when index.dims().size() == "
                      "2 in scatter_op.");
  } else {
    PADDLE_ENFORCE_EQ(index.dims().size(), 1,
                      "index.dims().size() should be 1 or 2 in scatter_op.");
  }
  const int64_t index_size = index.dims()[0];

  // Elements per row of src: product of dims[1:].
  const auto src_dims = src.dims();
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  const int64_t n = index_size * slice_size;
  // Nothing to scatter; launching zero blocks would be an invalid
  // configuration error.
  if (n == 0) return;

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  const auto& ctx = context.device_context();
  const auto stream =
      static_cast<const platform::CUDADeviceContext&>(ctx).stream();

  constexpr int kBlock = 512;
  const int grid = ScatterGridSize(n, kBlock);

  // In accumulate mode the targeted rows must start from zero.
  if (!overwrite) {
    ScatterInitCUDAKernel<T, IndexT><<<grid, kBlock, 0, stream>>>(
        p_index, p_output, index_size, slice_size, overwrite);
  }

  ScatterCUDAKernel<T, IndexT><<<grid, kBlock, 0, stream>>>(
      p_src, p_index, p_output, index_size, slice_size, overwrite);
}

/**
 * Scatter-adds `update` into `output` on the GPU using N-d indices.
 *
 * index:  IndexT tensor of shape [..., end_size]. Each of its
 *         remain_numel = prod(shape[:-1]) rows holds coordinates into the
 *         first end_size dimensions of output.
 * update: T tensor read as remain_numel rows of slice_size elements, where
 *         slice_size = prod(output.dims()[end_size:]).
 * output: accumulated into in place with atomic adds; existing contents are
 *         kept. Obtained via output->data<T>(), so presumably already
 *         allocated by the caller.
 *
 * output's dims are first copied into a temporary device buffer on the
 * context's stream. The kernel is then enqueued on that same stream. An empty
 * update (remain_numel == 0 or slice_size == 0) is a no-op.
 */
template <typename DeviceContext, typename T, typename IndexT = int>
void GPUScatterNdAdd(const framework::ExecutionContext& context,
                     const Tensor& update, const Tensor& index,
                     Tensor* output) {
  const auto index_dims = index.dims();
  const int index_dims_size = index_dims.size();

  const auto output_dims = output->dims();
  const int output_dims_size = output_dims.size();

  // Length of each coordinate tuple: the last dim of index.
  const int64_t end_size = index_dims[index_dims_size - 1];
  // Number of coordinate tuples: product of the remaining index dims.
  const auto remain_ddim =
      framework::slice_ddim(index_dims, 0, index_dims_size - 1);
  const int64_t remain_numel = framework::product(remain_ddim);
  // Elements per updated slice: product of the trailing output dims.
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < output_dims_size; ++i) {
    slice_size *= output_dims[i];
  }

  const int64_t n = remain_numel * slice_size;
  // Nothing to add; also avoids an invalid zero-block launch and a useless
  // temporary allocation.
  if (n == 0) return;

  const T* p_update = update.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  const auto& ctx = context.template device_context<DeviceContext>();
  const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
  const auto cplace = platform::CPUPlace();

  // Stage output's dims (as int) in device memory for the kernel.
  std::vector<int> v_output_dims(output_dims_size);
  for (int i = 0; i < output_dims_size; ++i) {
    v_output_dims[i] = static_cast<int>(output_dims[i]);
  }
  auto& dev_ctx = context.cuda_device_context();
  auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
  const int bytes = output_dims_size * sizeof(int);
  auto output_dims_ptr = allocator.Allocate(bytes);
  int* g_output_dims = reinterpret_cast<int*>(output_dims_ptr->ptr());
  memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes,
               ctx.stream());

  constexpr int kBlock = 512;
  const int grid = ScatterGridSize(n, kBlock);

  ScatterNdCUDAKernel<T, IndexT><<<grid, kBlock, 0, ctx.stream()>>>(
      p_update, p_index, p_output, g_output_dims, remain_numel, slice_size,
      end_size);
}

}  // namespace operators
}  // namespace paddle