// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cum_kernel.h"

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// Reverses up to BLOCK_SIZE consecutive elements with one thread block.
//
// For every tx < valid_item, thread tx ends up copying
// idata[src_base + tx] to odata[dst_base - tx]; i.e. the source range
// [src_base, src_base + valid_item) is written backwards, its first element
// landing at dst_base. The exchange goes through a BLOCK_SIZE-element
// __shared__ buffer indexed by threadIdx.x, so blockDim.x is assumed to equal
// BLOCK_SIZE (the launches in this file use 1024 threads with
// BLOCK_SIZE = 1024) — confirm for any new caller.
//
// Contains __syncthreads(): every thread of the block must call it.
template <typename T, int BLOCK_SIZE>
__device__ void BlockReverse(
    const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
  __shared__ T sh_mem[BLOCK_SIZE];
  int tx = threadIdx.x;

  int offset = tx;
  T src_data = static_cast<T>(0);
  // Thread `offset` stages the element that thread BLOCK_SIZE-1-offset will
  // store; slots whose source is past valid_item hold zero and are never
  // stored.
  int src_offset = BLOCK_SIZE - offset - 1;
  if (src_offset < valid_item) {
    src_data = idata[src_base + src_offset];
  }
  sh_mem[offset] = src_data;

  __syncthreads();
  int out_index = dst_base - offset;
  if (offset < valid_item) {
    int sh_mem_index = BLOCK_SIZE - offset - 1;
    odata[out_index] = sh_mem[sh_mem_index];
  }
  // sh_mem is the same static shared buffer on every call. Without this
  // barrier a fast thread could enter the next call (e.g. the next chunk of
  // MatrixRowReverse's loop) and overwrite its sh_mem slot before a slower
  // thread has read that slot above (write-after-read race).
  __syncthreads();
}

// Reverses contiguous rows of `reverse_size` elements from `matrix_data`
// into `reverse_data`.
//
// Each thread block owns the row that begins at
//   blockIdx.x * reverse_size + blockIdx.y * (inner_size * reverse_size)
// and reverses it chunk by chunk, 1024 elements at a time, via BlockReverse.
// Expects 1024 threads per block. `outer_size` is accepted but not read.
template <typename T>
__global__ void MatrixRowReverse(const T* matrix_data,
                                 T* reverse_data,
                                 int reverse_size,
                                 int outer_size,
                                 int inner_size) {
  constexpr int kChunk = 1024;

  const int row = blockIdx.x;
  const int slab = blockIdx.y;
  const int row_begin =
      row * reverse_size + slab * (inner_size * reverse_size);
  // Destination of the row's first element: the row's last slot.
  const int row_last = row_begin + reverse_size - 1;

  for (int done = 0; done < reverse_size; done += kChunk) {
    const int remaining = reverse_size - done;
    const int chunk_items = remaining < kChunk ? remaining : kChunk;
    BlockReverse<T, kChunk>(matrix_data,
                            reverse_data,
                            row_begin + done,
                            row_last - done,
                            chunk_items);
  }
}

// Stateful prefix functor for the block_prefix_callback_op overloads of
// cub::BlockScan. It lets a single thread block scan a sequence longer than
// one tile.
//
// Each call returns the running total accumulated over previous tiles, which
// seeds the current tile's scan. It then folds the current tile's aggregate
// into that total with `op_`. The functor is entered by the first warp of the
// block, and the value returned by thread 0 seeds the block-wide scan.
template <typename T, typename Op>
struct BlockPrefixCallbackOp {
  // Running prefix: combined value of all tiles processed so far.
  T running_total_;
  // Binary scan operator used to combine the prefix with each tile aggregate.
  Op op_;

  __device__ BlockPrefixCallbackOp(T running_total, Op op)
      : running_total_(running_total), op_(op) {}

  // Returns the prefix for the current tile and advances the running total
  // by the tile's aggregate.
  __device__ T operator()(T block_aggregate) {
    T old_prefix = running_total_;
    running_total_ = op_(old_prefix, block_aggregate);
    return old_prefix;
  }
};

// Tiled out-of-place transpose. `idata` is read as a row-major
// height x width matrix; `odata` receives it as a row-major width x height
// matrix.
//
// Expects blockDim = (TILE_DIM, BLOCK_ROWS) and a grid covering the input in
// TILE_DIM x TILE_DIM tiles (x over columns, y over rows). Each thread moves
// TILE_DIM / BLOCK_ROWS rows of its tile. The shared tile's rows are padded
// by one element so the column-wise reads in the write-back phase do not
// cause bank conflicts.
template <typename T, int TILE_DIM, int BLOCK_ROWS>
__global__ void MatrixTranspose(T* odata,
                                const T* idata,
                                size_t height,
                                size_t width) {
  __shared__ T tile[TILE_DIM][TILE_DIM + 1];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;

  // Phase 1: stage the input tile; slots outside the matrix are zero-filled.
  const int in_col = blockIdx.x * TILE_DIM + tx;
  const int in_row = blockIdx.y * TILE_DIM + ty;
  for (int step = 0; step < TILE_DIM; step += BLOCK_ROWS) {
    const int row = in_row + step;
    if (in_col < width && row < height) {
      tile[ty + step][tx] = idata[row * width + in_col];
    } else {
      tile[ty + step][tx] = 0;
    }
  }

  __syncthreads();

  // Phase 2: write the tile back with the block coordinates swapped.
  const int out_col = blockIdx.y * TILE_DIM + tx;
  const int out_row = blockIdx.x * TILE_DIM + ty;
  for (int step = 0; step < TILE_DIM; step += BLOCK_ROWS) {
    const int row = out_row + step;
    if (out_col < height && row < width) {
      odata[row * height + out_col] = tile[tx][ty + step];
    }
  }
}

// Binary operator computing log(exp(a) + exp(b)) in a numerically stable
// way. It is used as the scan operator for logcumsumexp.
//
// Rewritten as max + log1p(exp(min - max)) so that exp() never overflows.
// log1p keeps precision when exp(min - max) is tiny; log(1 + x) would round
// it to exactly 0.
struct LogAddExp {
  template <typename T>
  __host__ __device__ __forceinline__ T operator()(const T& a,
                                                   const T& b) const {
    T min_val = std::min(a, b);
    T max_val = std::max(a, b);
    // Both operands are the same infinity: min - max would be inf - inf = NaN,
    // while the correct result is that infinity (-inf for log(0 + 0)).
    if (min_val == max_val && std::isinf(max_val)) {
      return max_val;
    }
    return std::log1p(std::exp(min_val - max_val)) + max_val;
  }
};

// Identity element of a scan operator. BlockScanKernel uses it to seed the
// running prefix. Only the specializations below are defined.
template <typename T, typename Op>
struct Identity;

// Summation starts from zero.
template <typename T>
struct Identity<T, cub::Sum> {
  static constexpr T value = static_cast<T>(0);
};

// LogAddExp starts from the lowest finite value of T (not -infinity).
template <typename T>
struct Identity<T, LogAddExp> {
  static constexpr T value = std::numeric_limits<T>::lowest();
};

// Scans independent contiguous segments of `scan_size` elements, one thread
// block per segment. Block b reads d_in[b * scan_size, (b + 1) * scan_size)
// and writes the scanned values to the same positions of d_out.
//
// The segment is processed in tiles of BLOCK_THREADS * ITEMS_PER_THREAD
// items. A BlockPrefixCallbackOp seeded with Identity<MT, Op>::value carries
// the running prefix from one tile to the next. Values are accumulated in
// MT = MPTypeTrait<T>::Type. `exclusive` selects an exclusive scan instead
// of an inclusive one.
//
// Launch: 1-D grid with one block per segment, blockDim.x == BLOCK_THREADS.
// `inner_size` and `outer_size` are not read by this kernel.
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
__global__ void BlockScanKernel(T* d_out,
                                const T* d_in,
                                int inner_size,
                                int outer_size,
                                int scan_size,
                                bool exclusive,
                                Op op) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  // Specialize the BlockLoad, BlockStore and BlockScan collective types.
  typedef cub::
      BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
          BlockLoadT;
  typedef cub::BlockStore<MT,
                          BLOCK_THREADS,
                          ITEMS_PER_THREAD,
                          cub::BLOCK_STORE_TRANSPOSE>
      BlockStoreT;
  typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
  // The three collectives share one shared-memory allocation, so every
  // hand-off between them must be separated by __syncthreads().
  __shared__ union {
    typename BlockLoadT::TempStorage load;
    typename BlockStoreT::TempStorage store;
    typename BlockScanT::TempStorage scan;
  } temp_storage;

  int bx = blockIdx.x;
  BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);

  constexpr int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
  for (int block_offset = 0; block_offset < scan_size;
       block_offset += item_per_block) {
    // In-range items in this tile; smaller than item_per_block only for the
    // last tile of the segment.
    int valid_item = (scan_size - block_offset > item_per_block)
                         ? item_per_block
                         : (scan_size - block_offset);

    int offset = block_offset + bx * scan_size;

    // Out-of-range slots of the last tile are loaded as 0. They come after
    // every valid item, so they never reach a stored result.
    MT thread_keys[ITEMS_PER_THREAD];
    BlockLoadT(temp_storage.load)
        .Load(d_in + offset, thread_keys, valid_item, 0);

    __syncthreads();
    if (exclusive) {
      BlockScanT(temp_storage.scan)
          .ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
    } else {
      BlockScanT(temp_storage.scan)
          .InclusiveScan(thread_keys, thread_keys, op, prefix_op);
    }
    __syncthreads();

    BlockStoreT(temp_storage.store)
        .Store(d_out + offset, thread_keys, valid_item);
    // The next tile's Load reuses temp_storage. Without this barrier a fast
    // thread could overwrite it while slower threads are still inside Store.
    __syncthreads();
  }
}

// Cumulative sum over one contiguous range of `size` elements using thrust
// on dev_ctx's stream.
//
// - `exclusive` selects thrust::exclusive_scan instead of
//   thrust::inclusive_scan.
// - `reverse` runs the scan from the last element toward the first, using
//   reverse iterators over both input and output.
//
// Enabled only for non-float16 T; float16 has a separate no-op overload.
template <typename Context, typename T>
typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
                   const T* in_data,
                   T* out_data,
                   int64_t size,
                   bool reverse,
                   bool exclusive) {
#ifdef __HIPCC__
  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
  const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
  if (!reverse) {
    if (exclusive) {
      thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
    } else {
      thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
    }
    return;
  }

  // Reverse iterators built from one-past-the-end pointers walk each buffer
  // backwards, so the scan proceeds from the last element to the first.
  auto in_end = thrust::device_pointer_cast(in_data) + size;
  auto out_end = thrust::device_pointer_cast(out_data) + size;
  thrust::reverse_iterator<thrust::device_ptr<const T>> rev_in_first(in_end);
  thrust::reverse_iterator<thrust::device_ptr<T>> rev_out_first(out_end);
  auto rev_in_last = rev_in_first + size;
  if (exclusive) {
    thrust::exclusive_scan(policy, rev_in_first, rev_in_last, rev_out_first);
  } else {
    thrust::inclusive_scan(policy, rev_in_first, rev_in_last, rev_out_first);
  }
}

// float16 overload: intentionally does nothing. ScanKernel only takes the
// thrust path when T is not float16; this overload exists so that the call
// there still compiles for float16 instantiations.
template <typename Context, typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
ThrustCumsumKernel(const Context& /*dev_ctx*/,
                   const phi::dtype::float16* /*in_data*/,
                   phi::dtype::float16* /*out_data*/,
                   int64_t /*size*/,
                   bool /*reverse*/,
                   bool /*exclusive*/) {
  // Nothing to do.
}

// Shared GPU driver for cumsum and logcumsumexp. Scans `x` along `axis` with
// the binary operator `op` and writes the result into `out`.
//
// The tensor is viewed as a row-major [height, width] matrix, where
// height = prod(out_dims[0..axis]) and width = prod(out_dims[axis+1..]).
//
// - If `axis` is not the last dimension, the data is first transposed so that
//   every scan segment of length out_dims[axis] is contiguous. It is
//   transposed back at the end.
// - `reverse` is implemented by reversing each segment before and after a
//   forward scan.
// - Intermediates ping-pong between `out` and a temporary buffer, so the
//   final result lands in `out`.
//
// Fast path: a non-float16 cumsum (Op == cub::Sum) whose input size equals
// out_dims[axis] is delegated to thrust.
//
// `flatten` is not read here. Presumably shape inference has already
// reflected it in out->dims() — TODO confirm.
template <typename T, typename Context, typename Op>
void ScanKernel(const Context& dev_ctx,
                const DenseTensor& x,
                int axis,
                bool flatten,
                bool exclusive,
                bool reverse,
                Op op,
                DenseTensor* out) {
  auto out_dims = out->dims();
  auto size = x.numel();

  PADDLE_ENFORCE_EQ(
      axis < out_dims.size() && axis >= (0 - out_dims.size()),
      true,
      phi::errors::OutOfRange(
          "Attr(axis) is out of range, It's expected "
          "to be in range of [-%d, %d]. But received Attr(axis) = %d.",
          out_dims.size(),
          out_dims.size() - 1,
          axis));
  if (axis < 0) {
    axis += out_dims.size();
  }

  T* out_data = dev_ctx.template Alloc<T>(out);
  const T* in_data = x.data<T>();

  // Nothing to scan. Return before `outer_size = height / scan_size` below
  // (an integer division by zero when out_dims[axis] == 0) and before any
  // kernel is launched with a zero-sized grid.
  if (size == 0) {
    return;
  }

  // Use thrust for parallel acceleration when the input size is equal to the
  // length of the `axis` dimension.
  if (!std::is_same<T, phi::dtype::float16>::value &&
      std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
    ThrustCumsumKernel<Context, T>(
        dev_ctx, in_data, out_data, size, reverse, exclusive);
    return;
  }

  size_t height = 1;
  size_t width = 1;
  for (int i = 0; i <= axis; i++) {
    height *= out_dims[i];
  }
  for (int i = axis + 1; i < out_dims.size(); i++) {
    width *= out_dims[i];
  }
  int scan_size = out_dims[axis];
  bool transpose = (axis != out_dims.size() - 1);

  int tile_size = 32;
  dim3 blocks(32, 8);
  dim3 transpose_grids((width + tile_size - 1) / tile_size,
                       (height + tile_size - 1) / tile_size);

  // The temporary buffer is only needed by the multi-pass pipelines
  // (transpose and/or reverse). A plain scan reads `x` and writes `out`
  // directly.
  DenseTensor tmp_tensor;
  T* tmp_data = nullptr;
  if (transpose || reverse) {
    tmp_tensor.Resize(out_dims);
    tmp_data = dev_ctx.template Alloc<T>(&tmp_tensor);
  }

  // next_in_data holds the current intermediate result; next_out_data is the
  // destination of the next pass.
  T* next_in_data = out_data;
  T* next_out_data = tmp_data;
  if (transpose) {
    MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
        out_data, in_data, height, width);
  }
  auto swap_ptr = [](T*& ptr1, T*& ptr2) {
    T* tmp = ptr2;
    ptr2 = ptr1;
    ptr1 = tmp;
  };
  int outer_size = height / scan_size;
  int inner_size = width;
  // Consider the size of shared memory, here block size is 128
  dim3 scan_grid(outer_size, inner_size);
  dim3 reverse_grid = scan_grid;
  if (reverse) {
    if (transpose) {
      reverse_grid.x = scan_grid.y;
      reverse_grid.y = scan_grid.x;
      MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
          next_in_data, next_out_data, scan_size, outer_size, inner_size);
      swap_ptr(next_in_data, next_out_data);
    } else {
      MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
          in_data, out_data, scan_size, outer_size, inner_size);
    }
  }
  // One block per scan segment; widen before multiplying to avoid int
  // overflow.
  int64_t grid_size = static_cast<int64_t>(outer_size) * inner_size;
  if (!transpose && !reverse) {
    BlockScanKernel<T, 128, 4, Op><<<grid_size, 128, 0, dev_ctx.stream()>>>(
        out_data, in_data, outer_size, inner_size, scan_size, exclusive, op);
  } else {
    BlockScanKernel<T, 128, 4, Op>
        <<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
                                                  next_in_data,
                                                  outer_size,
                                                  inner_size,
                                                  scan_size,
                                                  exclusive,
                                                  op);
  }
  swap_ptr(next_in_data, next_out_data);
  if (reverse) {
    MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
        next_in_data, next_out_data, scan_size, outer_size, inner_size);
    swap_ptr(next_in_data, next_out_data);
  }
  if (transpose) {
    // Transpose back: the intermediate is a [width, height] matrix.
    transpose_grids.x = (height + tile_size - 1) / tile_size;
    transpose_grids.y = (width + tile_size - 1) / tile_size;
    MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
        next_out_data, next_in_data, width, height);
  }
}

// GPU cumsum. Computes the cumulative sum of `x` along `axis`, where `axis` is
// a Scalar converted to int. Delegates to ScanKernel with cub::Sum.
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& axis,
                  bool flatten,
                  bool exclusive,
                  bool reverse,
                  DenseTensor* out) {
  using Op = cub::Sum;
  auto op = Op();
  ScanKernel<T, Context, Op>(
      dev_ctx, x, axis.to<int>(), flatten, exclusive, reverse, op, out);
}

// GPU logcumsumexp. Computes log(cumsum(exp(x))) along `axis` as a scan with
// the LogAddExp operator, delegating to ScanKernel.
template <typename T, typename Context>
void LogcumsumexpKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        int axis,
                        bool flatten,
                        bool exclusive,
                        bool reverse,
                        DenseTensor* out) {
  ScanKernel<T, Context, LogAddExp>(
      dev_ctx, x, axis, flatten, exclusive, reverse, LogAddExp(), out);
}

}  // namespace phi

// GPU kernel registrations. The HIP build registers no phi::dtype::float16
// instantiations; the CUDA build adds float16 to both kernels.
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cumsum,
                   GPU,
                   ALL_LAYOUT,
                   phi::CumsumKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(
    logcumsumexp, GPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
#else
PD_REGISTER_KERNEL(cumsum,
                   GPU,
                   ALL_LAYOUT,
                   phi::CumsumKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(logcumsumexp,
                   GPU,
                   ALL_LAYOUT,
                   phi::LogcumsumexpKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
#endif