/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"

namespace phi {
namespace funcs {

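// Assignment functor: writes src directly into self (no reduction). Used by
// gather and by scatter_assign.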
class TensorAssign {
 public:
  template <typename tensor_t>
  constexpr void operator()(tensor_t* self_data, tensor_t* src_data) const {
    *self_data = *src_data;
  }
};
static TensorAssign tensor_assign;

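// Addition functor. The generic overload uses phi::CudaAtomicAdd so that
// concurrent updates to the same output element accumulate correctly; CUDA
// has no 8-bit atomicAdd, so the uint8_t overload falls back to a plain
// (non-atomic) add.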
class ReduceAdd {
 public:
  template <
      typename tensor_t,
      std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
    phi::CudaAtomicAdd(self_data, *src_data);
  }
  template <typename tensor_t,
            std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
    *self_data += *src_data;
  }
};
static ReduceAdd reduce_add;

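// Multiplication functor. This multiply is not atomic, so concurrent updates
// to the same output element may race; see the TODO below.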
class ReduceMul {
 public:
  template <typename tensor_t>
  __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
    *self_data *= *src_data;
    // TODO(huangxu96) platform::CudaAtomicMul(*self_data, *src_data);
  }
};
static ReduceMul reduce_mul;

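// One thread per element of the index tensor: tid is decomposed into the
// squeezed (i, j, k) coordinates, and reduce_op combines one src element into
// self at either tid or the replaced offset, depending on is_scatter_like.
// reduce_op is taken by value because kernel arguments must not reference
// host memory.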
template <typename tensor_t,
          typename index_t,
          typename func_t,
          bool is_scatter_like = true>
__global__ void GatherScatterGPUKernel(tensor_t* self_data,
                                       int dim,
                                       const index_t* index_data,
                                       tensor_t* src_data,
                                       int64_t inner_dim_size,
                                       int select_dim_size,
                                       int replaced_select_dim_size,
                                       int64_t outer_dim_size,
                                       int64_t numel,
                                       func_t reduce_op) {
  int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid >= numel) return;
  int64_t i, j, k;  // i, j, k are the indices of the three-level loop
                    // squeezed from the original N-level loop.
  /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
  i = tid / (select_dim_size * outer_dim_size);
  int64_t remind = tid % (select_dim_size * outer_dim_size);
  j = remind / outer_dim_size;
  k = remind % outer_dim_size;
  index_t index = index_data[tid];
  /*
    gather computation formula:

    self[i][j][k] = src[index[i][j][k]][j][k]  # if dim == 0
    self[i][j][k] = src[i][index[i][j][k]][k]  # if dim == 1
    self[i][j][k] = src[i][j][index[i][j][k]]  # if dim == 2

    scatter computation formula:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

  */
  // The index tensor may have a different shape from self or src.
  int64_t replace_index = k + index * outer_dim_size +
                          i * outer_dim_size * replaced_select_dim_size;
  int64_t self_idx = is_scatter_like ? replace_index : tid;
  int64_t src_idx = is_scatter_like ? tid : replace_index;
  reduce_op(self_data + self_idx, src_data + src_idx);
}

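// Host-side launcher: derives the squeezed loop bounds from the index
// tensor's shape and launches GatherScatterGPUKernel with one thread per
// index element.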
template <typename tensor_t,
          typename index_t = int64_t,
          bool is_scatter_like = true>
struct gpu_gather_scatter_functor {
  template <typename func_t>
  void operator()(phi::DenseTensor self,
                  int dim,
                  const phi::DenseTensor& index,
                  phi::DenseTensor src,
                  const std::string& method_name,
                  const func_t& reduce_op,
                  const phi::DeviceContext& ctx) {
    if (index.numel() == 0) {
      return;
    }
    auto* self_data = self.data<tensor_t>();
    auto* index_data = index.data<index_t>();
    auto* src_data = src.data<tensor_t>();
    int64_t self_size = self.numel();
    int64_t index_size = index.numel();
    int64_t src_size = src.numel();
    auto self_dims = self.dims();
    auto index_dims = index.dims();
    auto src_dims = src.dims();
    if (self_size == 0 || src_size == 0 || index_size == 0) return;
    int select_dim_size = index_dims[dim];
    // The index tensor may have a different shape from self or src.
    int replaced_select_dim_size =
        is_scatter_like ? self_dims[dim] : src_dims[dim];
    int64_t inner_dim_size = 1;
    int64_t outer_dim_size = 1;
    for (int64_t i = 0; i < dim; ++i) {
      inner_dim_size *= index_dims[i];
    }

    for (int i = dim + 1; i < index_dims.size(); i++) {
      outer_dim_size *= index_dims[i];
    }

    int block = 512;
    int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
    int64_t grid = (n + block - 1) / block;
    auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
    GatherScatterGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
        <<<grid, block, 0, stream>>>(self_data,
                                     dim,
                                     index_data,
                                     src_data,
                                     inner_dim_size,
                                     select_dim_size,
                                     replaced_select_dim_size,
                                     outer_dim_size,
                                     index_size,
                                     reduce_op);
  }
};  // struct gpu_gather_scatter_functor

template <typename tensor_t, typename index_t>
void gpu_gather_kernel(phi::DenseTensor self,
                       int dim,
                       const phi::DenseTensor& index,
                       phi::DenseTensor result,
                       const phi::DeviceContext& ctx) {
  gpu_gather_scatter_functor<tensor_t,
                             index_t,
                             /*is_scatter_like=*/false>()(
      result, dim, index, self, "gather_out_gpu", tensor_assign, ctx);
}

template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(phi::DenseTensor self,
                               int dim,
                               const phi::DenseTensor& index,
                               phi::DenseTensor src,
                               const phi::DeviceContext& ctx) {
  gpu_gather_scatter_functor<tensor_t,
                             index_t,
                             /*is_scatter_like=*/true>()(
      self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx);
}

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(phi::DenseTensor self,
                            int dim,
                            const phi::DenseTensor& index,
                            phi::DenseTensor src,
                            const phi::DeviceContext& ctx) {
  gpu_gather_scatter_functor<tensor_t,
                             index_t,
                             /*is_scatter_like=*/true>()(
      self, dim, index, src, "scatter_add_gpu", reduce_add, ctx);
}
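// A minimal usage sketch (illustrative only): `dev_ctx` is assumed to be a
// phi::GPUContext, and `self`, `index`, `src` GPU DenseTensors prepared by
// the caller with the shapes described above.
//
//   phi::funcs::gpu_scatter_add_kernel<float, int64_t>(
//       self, /*dim=*/0, index, src, dev_ctx);
//
// For dim == 0 this computes self[index[i][j][k]][j][k] += src[i][j][k],
// with duplicate indices accumulated atomically by ReduceAdd.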

template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(phi::DenseTensor self,
                            int dim,
                            const phi::DenseTensor& index,
                            phi::DenseTensor src,
                            const phi::DeviceContext& ctx) {
  gpu_gather_scatter_functor<tensor_t,
                             index_t,
                             /*is_scatter_like=*/true>()(
      self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx);
}

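// Backward helper for scatter: zeroes the slots of the input gradient that
// scatter overwrote, since those elements of self do not reach the output.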
template <typename tensor_t, typename index_t>
__global__ void ScatterInputGradGPUKernel(tensor_t* grad_data,
                                          int dim,
                                          const index_t* index_data,
                                          int64_t inner_dim_size,
                                          int select_dim_size,
                                          int grad_select_dim_size,
                                          int64_t outer_dim_size,
                                          int64_t numel) {
  int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid >= numel) return;
  int64_t i, j, k;
  i = tid / (select_dim_size * outer_dim_size);
  int64_t remind = tid % (select_dim_size * outer_dim_size);
  j = remind / outer_dim_size;
  k = remind % outer_dim_size;
  index_t index = index_data[tid];
  int64_t replace_index =
      k + index * outer_dim_size + i * outer_dim_size * grad_select_dim_size;
  grad_data[replace_index] = 0;
}
template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(phi::DenseTensor self,
                                   int dim,
                                   const phi::DenseTensor& index,
                                   phi::DenseTensor grad,
                                   const phi::DeviceContext& ctx) {
  auto* index_data = index.data<index_t>();
  auto* grad_data = grad.data<tensor_t>();

  auto index_dims = index.dims();
  auto grad_dims = grad.dims();
  int64_t index_size = index.numel();

  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;
  int select_dim_size = index_dims[dim];
  int grad_select_dim_size = grad_dims[dim];
  for (int64_t i = 0; i < dim; ++i) {
    inner_dim_size *= index_dims[i];
  }

  for (int i = dim + 1; i < index_dims.size(); i++) {
    outer_dim_size *= index_dims[i];
  }

  int block = 512;
  int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
  int64_t grid = (n + block - 1) / block;
  auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();

  ScatterInputGradGPUKernel<tensor_t, index_t>
      <<<grid, block, 0, stream>>>(grad_data,
                                   dim,
                                   index_data,
                                   inner_dim_size,
                                   select_dim_size,
                                   grad_select_dim_size,
                                   outer_dim_size,
                                   index_size);
}
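// Explicit instantiations for the supported (tensor_t, index_t) combinations;
// the Instantiate_Template_Function macro is declared in
// gather_scatter_functor.h.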
Instantiate_Template_Function(gpu_gather_kernel)
Instantiate_Template_Function(gpu_scatter_assign_kernel)
Instantiate_Template_Function(gpu_scatter_add_kernel)
Instantiate_Template_Function(gpu_scatter_mul_kernel)
Instantiate_Template_Function(gpu_scatter_input_grad_kernel)

}  // namespace funcs
}  // namespace phi