/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <vector>

#include "math/math_function.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

/**
 * Zeroes every output element that ScatterCUDAKernel will later touch, so the
 * accumulate (non-overwrite) mode starts from 0 for each referenced row.
 *
 * Launch: any 1-D grid/block; threads cover [0, index_size * slice_size) with
 * a 64-bit grid-stride loop. No shared memory.
 *
 * @param indices    device array of index_size row indices into `output`.
 *                   Values are not range-checked here.
 * @param output     device buffer viewed as rows of slice_size elements.
 * @param index_size number of entries in `indices`.
 * @param slice_size number of elements per row.
 * @param overwrite  unused; kept so the signature stays unchanged.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
                                      size_t index_size, size_t slice_size,
                                      bool overwrite) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t total = static_cast<int64_t>(index_size) * slice;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    const int64_t scatter_i = static_cast<int64_t>(indices[indices_i]);
    // 64-bit offset: scatter_i * slice can exceed the range of IndexT/int.
    output[scatter_i * slice + slice_i] = static_cast<T>(0);
  }
}

/**
 * Writes row k of `params` into row indices[k] of `output`.
 *
 * Launch: any 1-D grid/block; threads cover [0, index_size * slice_size) with
 * a 64-bit grid-stride loop. No shared memory.
 *
 * @param params     device buffer of index_size rows x slice_size elements.
 * @param indices    device array of index_size row indices into `output`.
 *                   Values are not range-checked here.
 * @param output     device buffer viewed as rows of slice_size elements.
 * @param index_size number of rows to scatter.
 * @param slice_size number of elements per row.
 * @param overwrite  true: plain stores. With duplicate indices, which row's
 *                   value survives is unspecified.
 *                   false: CudaAtomicAdd, so duplicates accumulate.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
                                  T* output, size_t index_size,
                                  size_t slice_size, bool overwrite) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t total = static_cast<int64_t>(index_size) * slice;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    const int64_t scatter_i = static_cast<int64_t>(indices[indices_i]);
    const int64_t out_i = scatter_i * slice + slice_i;
    if (overwrite) {
      output[out_i] = params[i];
    } else {
      paddle::platform::CudaAtomicAdd(output + out_i, params[i]);
    }
  }
}

/**
 * Atomically adds each slice of `update` into `output` at the position
 * addressed by an end_size-long index tuple.
 *
 * `indices` is read as remain_size rows of end_size values. Tuple r selects
 * the output offset
 *   sum_j indices[r][j] * slice_size * prod(output_dims[j+1 .. end_size-1]),
 * which is a row-major offset over output's leading end_size dims.
 * update[r * slice_size + s] is then added to output[offset + s].
 *
 * Launch: any 1-D grid/block; threads cover [0, remain_size * slice_size)
 * with a 64-bit grid-stride loop. No shared memory.
 *
 * @param output_dims device array holding (at least) output's first
 *                    end_size dimension sizes.
 */
template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
                                    T* output, const int* output_dims,
                                    size_t remain_size, size_t slice_size,
                                    size_t end_size) {
  const int64_t slice = static_cast<int64_t>(slice_size);
  const int64_t end = static_cast<int64_t>(end_size);
  const int64_t total = static_cast<int64_t>(remain_size) * slice;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total; i += stride) {
    const int64_t indices_i = i / slice;
    const int64_t slice_i = i - indices_i * slice;  // offset inside the slice
    // Accumulate in 64 bits; an IndexT (often int) accumulator can overflow.
    int64_t gather_i = 0;
    int64_t temp = slice;
    for (int64_t j = end - 1; j >= 0; --j) {
      const int64_t index_value =
          static_cast<int64_t>(indices[indices_i * end + j]);
      gather_i += index_value * temp;
      temp *= output_dims[j];
    }
    paddle::platform::CudaAtomicAdd(output + gather_i + slice_i, update[i]);
  }
}

/**
 * Number of blocks of `block_size` threads needed to cover `n` work items.
 * The result is clamped to 2^31 - 1, the maximum grid x-dimension on compute
 * capability 3.0+. The kernels above use grid-stride loops, so clamping never
 * drops work. Expects n > 0 and block_size > 0.
 */
inline int GetScatterGridDim(int64_t n, int block_size) {
  const int64_t blocks = (n + block_size - 1) / block_size;
  const int64_t max_blocks = std::numeric_limits<int>::max();
  return static_cast<int>(blocks < max_blocks ? blocks : max_blocks);
}

/**
 * Scatters the rows of `src` into `output` on the GPU, in place.
 *
 * For each k in [0, index.dims()[0]), row k of `src` is written to (or, when
 * overwrite == false, added into) row index[k] of `output`. A row is
 * prod(src.dims()[1:]) elements. Rows of `output` not referenced by `index`
 * are not modified.
 *
 * @param context   execution context; must be on a CUDA place.
 * @param src       type-T source tensor. Its leading dim is assumed to equal
 *                  index.dims()[0], and `output` is assumed to have the same
 *                  trailing dims. Neither is checked here.
 * @param index     type-IndexT index tensor of shape [N] or [N, 1]. Values
 *                  are not range-checked; callers must keep them in
 *                  [0, output->dims()[0]).
 * @param output    destination tensor, already allocated.
 * @param overwrite true: plain stores; with duplicate indices, which row wins
 *                  is unspecified. false: referenced rows are zeroed first,
 *                  then accumulated with atomic adds.
 *
 * Kernels are enqueued on the context's CUDA stream; this call does not
 * synchronize. Throws InvalidArgument (via PADDLE_ENFORCE_EQ) for a bad
 * index rank or shape.
 */
template <typename T, typename IndexT = int>
void GPUScatterAssign(const framework::ExecutionContext& context,
                      const Tensor& src, const Tensor& index, Tensor* output,
                      bool overwrite = true) {
  // index must be 1-D, or 2-D with a trailing dimension of 1.
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
                      platform::errors::InvalidArgument(
                          "index.dims()[1] should be 1 when "
                          "index.dims().size() = 2 in scatter_op. "
                          "But received value is [%d]",
                          index.dims()[1]));
  } else {
    PADDLE_ENFORCE_EQ(index.dims().size(), 1,
                      platform::errors::InvalidArgument(
                          "index.dims().size() should be 1 or 2 in scatter_op. "
                          "But received value is [%d]",
                          index.dims().size()));
  }
  const int64_t index_size = index.dims()[0];

  // slice size: number of elements in one row of src.
  const auto src_dims = src.dims();
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // Nothing to scatter. Also avoids a launch with grid == 0, which is an
  // invalid configuration.
  const int64_t n = index_size * slice_size;
  if (n == 0) return;

  const auto stream = context.cuda_device_context().stream();
  const int block = 512;
  const int grid = GetScatterGridDim(n, block);

  // In accumulate mode, zero the target rows first. Both kernels run on the
  // same stream, so the init finishes before the scatter starts.
  if (!overwrite) {
    ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, stream>>>(
        p_index, p_output, index_size, slice_size, overwrite);
  }

  ScatterCUDAKernel<T, IndexT><<<grid, block, 0, stream>>>(
      p_src, p_index, p_output, index_size, slice_size, overwrite);
}

/**
 * Adds slices of `update` into `output` at positions given by index tuples.
 *
 * The last dim of `index` (end_size) is the tuple length. The product of its
 * leading dims (remain_numel) is the number of tuples. Each tuple addresses
 * output's first end_size dims. The matching slice of
 * prod(output->dims()[end_size:]) elements from `update` is atomically added
 * there, so duplicate tuples accumulate.
 *
 * Output dims are copied to a temporary device buffer on ctx's stream before
 * the kernel is enqueued on the same stream. This call does not synchronize.
 * Tuple values are not range-checked.
 */
template <typename DeviceContext, typename T, typename IndexT = int>
void GPUScatterNdAdd(const framework::ExecutionContext& context,
                     const Tensor& update, const Tensor& index,
                     Tensor* output) {
  const auto index_dims = index.dims();
  const auto index_dims_size = index_dims.size();

  const auto output_dims = output->dims();
  const auto output_dims_size = output_dims.size();

  const T* p_update = update.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // final dim: length of each index tuple.
  const int64_t end_size = index_dims[index_dims_size - 1];
  // remain dim: number of index tuples.
  const auto remain_ddim =
      framework::slice_ddim(index_dims, 0, index_dims_size - 1);
  const int64_t remain_numel = framework::product(remain_ddim);
  // slice size: elements of output not addressed by an index tuple.
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < output_dims_size; ++i) {
    slice_size *= output_dims[i];
  }

  // Nothing to add. Also avoids a launch with grid == 0.
  const int64_t n = remain_numel * slice_size;
  if (n == 0) return;

  const auto& ctx = context.template device_context<DeviceContext>();
  const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
  const auto cplace = platform::CPUPlace();

  // Copy output_dims to the device so the kernel can compute offsets.
  std::vector<int> v_output_dims(output_dims_size);
  for (int i = 0; i < output_dims_size; ++i) {
    v_output_dims[i] = static_cast<int>(output_dims[i]);
  }
  const size_t bytes = output_dims_size * sizeof(int);
  auto output_dims_ptr = memory::Alloc(ctx, bytes);
  int* g_output_dims = reinterpret_cast<int*>(output_dims_ptr->ptr());
  memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes,
               ctx.stream());

  const int block = 512;
  const int grid = GetScatterGridDim(n, block);

  ScatterNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
      p_update, p_index, p_output, g_output_dims, remain_numel, slice_size,
      end_size);
}

}  // namespace operators
}  // namespace paddle