// inclusive_scan.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include <thrust/device_ptr.h>
#include <thrust/iterator/reverse_iterator.h>
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/for_range.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/malloc.h"

namespace phi {
namespace funcs {

// Compile-time predicate: IsComplex<T>::value is true only for the two
// phi complex types specialized below, and false for every other T.
template <typename T>
struct IsComplex : std::false_type {};

// complex<float> is a complex type.
template <>
struct IsComplex<::phi::dtype::complex<float>> : std::true_type {};

// complex<double> is a complex type.
template <>
struct IsComplex<::phi::dtype::complex<double>> : std::true_type {};

// Runs cub::DeviceScan::InclusiveScan over `n` items read from `x_iter` and
// written to `y_iter`, combining elements with `op`, on the stream of
// `dev_ctx`.
//
// x_iter   iterator to the first input element.
// y_iter   iterator to the first output element.
// n        number of elements to scan; must be representable as `int`.
// op       binary scan operator.
// dev_ctx  GPU context that supplies the stream and the temporary workspace.
//
// Raises an InvalidArgument error when `n` does not fit in `int`, and a GPU
// error (via PADDLE_ENFORCE_GPU_SUCCESS) when either CUB call fails.
template <typename InputIterator, typename OutputIterator, typename BinaryOp>
static void CubInclusiveScan(InputIterator x_iter,
                             OutputIterator y_iter,
                             size_t n,
                             BinaryOp op,
                             const phi::GPUContext &dev_ctx) {
  // The item count is handed to CUB as an `int`. Reject counts that would be
  // truncated by the narrowing cast instead of silently scanning a wrong
  // number of elements.
  const int num_items = static_cast<int>(n);
  PADDLE_ENFORCE_EQ(
      num_items >= 0 && static_cast<size_t>(num_items) == n,
      true,
      phi::errors::InvalidArgument(
          "The number of elements to scan (%d) is out of the range of int, "
          "which is not supported by cub::DeviceScan::InclusiveScan.",
          n));

  // Phase 1: with a null workspace pointer CUB only reports how many bytes of
  // temporary storage it needs and performs no work.
  void *temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(
      cub::DeviceScan::InclusiveScan(temp_storage,
                                     temp_storage_bytes,
                                     x_iter,
                                     y_iter,
                                     op,
                                     num_items,
                                     dev_ctx.stream()));

  // Always hand CUB a non-null workspace: a null pointer in phase 2 would be
  // interpreted as another size query and the scan would never run.
  if (temp_storage_bytes == 0) {
    temp_storage_bytes = 1;
  }
  paddle::memory::allocation::AllocationPtr allocation =
      paddle::memory::Alloc(dev_ctx.GetPlace(), temp_storage_bytes);
  temp_storage = allocation->ptr();

  // Phase 2: the actual scan.
  PADDLE_ENFORCE_GPU_SUCCESS(
      cub::DeviceScan::InclusiveScan(temp_storage,
                                     temp_storage_bytes,
                                     x_iter,
                                     y_iter,
                                     op,
                                     num_items,
                                     dev_ctx.stream()));
}

template <typename T>
static auto MakeThrustReverseIterator(T *x) {
  return thrust::reverse_iterator<thrust::device_ptr<T>>(
      thrust::device_pointer_cast(x));
}

// Sequential inclusive scan along the middle dimension of an array indexed as
// [outer][mid_dim][inner_dim]. One invocation of operator() handles a single
// (outer, inner) pair and walks all `mid_dim` elements of that line, which
// are `inner_dim` elements apart in memory.
//
// kReverse selects the walking direction: false starts at mid index 0 and
// moves forward, true starts at mid index mid_dim - 1 and moves backward.
template <typename T, typename BinaryOp, bool kReverse>
struct InclusiveScanOuterOrMidDimFunctor {
  // x, y       input and output buffers (same indexing).
  // mid_dim    length of the scanned dimension.
  // inner_dim  number of elements between consecutive scanned items.
  // init       starting value of the accumulator.
  // op         binary operator, invoked as op(accumulator, element).
  HOSTDEVICE InclusiveScanOuterOrMidDimFunctor(
      const T *x, T *y, size_t mid_dim, size_t inner_dim, T init, BinaryOp op)
      : x_(x),
        y_(y),
        mid_dim_(mid_dim),
        inner_dim_(inner_dim),
        init_(init),
        op_(op) {}

  // `idx` is a flat index over (outer, inner) pairs: outer = idx / inner_dim,
  // inner = idx % inner_dim.
  HOSTDEVICE void operator()(size_t idx) const {
    const size_t outer = idx / inner_dim_;
    const size_t inner = idx % inner_dim_;

    // Offset of the first element visited on this line.
    size_t start = outer * mid_dim_ * inner_dim_;
    if (kReverse) {
      start += (mid_dim_ - 1) * inner_dim_;
    }
    start += inner;

    const T *src = x_ + start;
    T *dst = y_ + start;
    T running = init_;
    for (size_t remaining = mid_dim_; remaining != 0; --remaining) {
      running = op_(running, *src);
      *dst = running;
      // Step to the neighbouring element along the scanned dimension.
      if (kReverse) {
        src -= inner_dim_;
        dst -= inner_dim_;
      } else {
        src += inner_dim_;
        dst += inner_dim_;
      }
    }
  }

 private:
  const T *x_;
  T *y_;
  size_t mid_dim_;
  size_t inner_dim_;
  T init_;
  BinaryOp op_;
};

// Inclusive scan along each row of a [num_rows, row_size] array whose rows
// are contiguous (row r starts at x + r * row_size).
//
// Layout assumed by the indexing below:
//   * threadIdx.y selects a row inside the block and must be < kThreadNumY;
//   * threadIdx.x selects a lane inside the row and must be < kThreadNumX;
//   * each lane handles two elements per tile, so a tile covers
//     2 * kThreadNumX columns;
//   * blocks step over rows by gridDim.x * kThreadNumY, so any grid size in x
//     covers all rows.
//
// Shared memory: a static buffer holding 2 * kThreadNumX elements of T for
// each of the kThreadNumY rows handled by a block.
//
// kReverse == true scans each row from its last column towards column 0.
//
// `init` seeds the running total of each row and is also used to pad slots
// that fall beyond the end of the row.
// NOTE(review): this presumably requires `init` to be the identity of `op`
// -- confirm against callers.
template <typename T,
          typename BinaryOp,
          size_t kThreadNumX,
          size_t kThreadNumY,
          bool kReverse>
static __global__ void InclusiveScanInnerDimCUDAKernel(
    const T *x, T *y, size_t num_rows, size_t row_size, T init, BinaryOp op) {
  // The buffer is declared with the real-valued type and reinterpreted as T
  // below; for complex T it is twice as wide so that it still holds
  // 2 * kThreadNumX elements of T per row.
  using RealT = phi::dtype::Real<T>;
  constexpr auto kSharedBufferSize =
      IsComplex<T>::value ? 4 * kThreadNumX : 2 * kThreadNumX;
  __shared__ RealT sbuf[kThreadNumY][kSharedBufferSize];
  // Scratch area private to the row handled by this threadIdx.y.
  T *row_buf = reinterpret_cast<T *>(sbuf[threadIdx.y]);

  size_t block_row = static_cast<size_t>(blockIdx.x * kThreadNumY);
  size_t block_row_stride = static_cast<size_t>(gridDim.x * kThreadNumY);
  // The loop bounds depend only on block-wide values, so every thread of the
  // block runs the same number of iterations and reaches every barrier.
  for (; block_row < num_rows; block_row += block_row_stride) {
    size_t row = block_row + threadIdx.y;
    // Running result carried from the tiles of this row already processed.
    T block_total = init;

    const T *row_x = x + row * row_size;
    T *row_y = y + row * row_size;
    for (size_t block_col = 0; block_col < row_size;
         block_col += 2 * kThreadNumX) {
      // Columns of the two elements owned by this lane in the current tile.
      // In reverse mode the columns count down from the end of the row; an
      // index that would go below zero wraps around (size_t) and therefore
      // fails the `< row_size` checks below.
      size_t col1, col2;
      if (kReverse) {
        col1 = row_size - 1 - block_col - threadIdx.x;
        col2 = col1 - kThreadNumX;
      } else {
        col1 = block_col + threadIdx.x;
        col2 = col1 + kThreadNumX;
      }

      // Stage the tile in shared memory, padding out-of-range slots with
      // `init`. Rows past num_rows skip the work but still hit the barriers.
      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_x[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[kThreadNumX + threadIdx.x] = row_x[col2];
        } else {
          row_buf[kThreadNumX + threadIdx.x] = init;
        }

        // Fold the total of the previous tiles into the first element so the
        // scan of this tile continues from it. Lane 0 wrote row_buf[0] itself
        // just above, so no barrier is needed before this read.
        if (threadIdx.x == 0) {
          row_buf[0] = op(row_buf[0], block_total);
        }
      }
      __syncthreads();

      // First pass: the distance d doubles each round while the number of
      // active lanes s halves. Each active lane combines the partial result
      // at `offset` into the slot d positions later, so the last slot ends up
      // holding the combination of the whole tile.
      for (size_t s = kThreadNumX, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          size_t offset = (2 * threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Second pass: d halves each round while the number of active lanes
      // grows, pushing the partial results into the slots the first pass left
      // incomplete. Afterwards row_buf holds the values written to y below.
      for (size_t s = 2, d = kThreadNumX / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          size_t offset = 2 * (threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Write back only the slots that map to real columns.
      if (row < num_rows) {
        if (col1 < row_size) row_y[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_y[col2] = row_buf[kThreadNumX + threadIdx.x];
      }
      // The last slot is the running total carried into the next tile. The
      // barrier keeps the next tile's loads from overwriting row_buf before
      // every lane has read it.
      block_total = row_buf[2 * kThreadNumX - 1];
      __syncthreads();
    }
  }
}

// Launches InclusiveScanInnerDimCUDAKernel on the stream of `dev_ctx` to scan
// every row of a [outer_dim, inner_dim] array along its inner dimension.
//
// x, y       input and output buffers.
// outer_dim  number of rows.
// inner_dim  number of elements per row (the scanned length).
// init       value forwarded to the kernel as its `init` argument.
// op         binary scan operator.
// reverse    when true the kernel is instantiated with kReverse = true.
//
// Each block is kThreadNumX x kThreadNumY threads. The grid has one block per
// kThreadNumY rows, capped at the x-dimension grid limit reported by
// `dev_ctx`.
template <typename T, typename BinaryOp>
static void InclusiveScanInnerDim(const T *x,
                                  T *y,
                                  size_t outer_dim,
                                  size_t inner_dim,
                                  T init,
                                  BinaryOp op,
                                  bool reverse,
                                  const phi::GPUContext &dev_ctx) {
  constexpr size_t kThreadNumX = 16;
  constexpr size_t kThreadNumY = 32;

  const size_t blocks_for_all_rows =
      (outer_dim + kThreadNumY - 1) / kThreadNumY;
  const size_t num_blocks = std::min<size_t>(
      blocks_for_all_rows, dev_ctx.GetCUDAMaxGridDimSize()[0]);
  const dim3 block_shape(kThreadNumX, kThreadNumY);

  if (!reverse) {
    InclusiveScanInnerDimCUDAKernel<
        T,
        BinaryOp,
        kThreadNumX,
        kThreadNumY,
        /*kReverse=*/false><<<num_blocks, block_shape, 0, dev_ctx.stream()>>>(
        x, y, outer_dim, inner_dim, init, op);
    return;
  }

  InclusiveScanInnerDimCUDAKernel<
      T,
      BinaryOp,
      kThreadNumX,
      kThreadNumY,
      /*kReverse=*/true><<<num_blocks, block_shape, 0, dev_ctx.stream()>>>(
      x, y, outer_dim, inner_dim, init, op);
}

// Inclusive scan along the middle dimension of an array viewed as
// [outer_dim, mid_dim, inner_dim]. Does nothing when any dimension is zero.
//
// Dispatch:
//   * outer_dim == 1 && inner_dim == 1: a single contiguous sequence, scanned
//     by CubInclusiveScan (through reverse iterators when `reverse` is set).
//     `init` is not forwarded on this path.
//   * inner_dim == 1 (with outer_dim != 1): every row is contiguous, handled
//     by InclusiveScanInnerDim with mid_dim as the row length.
//   * otherwise: one sequential scan per (outer, inner) pair, run through
//     ForRange over outer_dim * inner_dim indices.
//
// x, y     input and output buffers.
// init     starting value used by the two non-CUB paths.
// op       binary scan operator.
// reverse  scan from the last element of the middle dimension to the first.
// dev_ctx  GPU context providing the stream.
template <typename T, typename BinaryOp>
void InclusiveScan(const T *x,
                   T *y,
                   size_t outer_dim,
                   size_t mid_dim,
                   size_t inner_dim,
                   T init,
                   BinaryOp op,
                   bool reverse,
                   const phi::GPUContext &dev_ctx) {
  if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;

  const bool single_sequence = (outer_dim == 1 && inner_dim == 1);
  if (single_sequence) {
    if (!reverse) {
      CubInclusiveScan(x, y, mid_dim, op, dev_ctx);
      return;
    }
    // Reverse iterators anchored one past the end visit the elements from
    // last to first.
    auto x_from_back = MakeThrustReverseIterator(x + mid_dim);
    auto y_from_back = MakeThrustReverseIterator(y + mid_dim);
    CubInclusiveScan(x_from_back, y_from_back, mid_dim, op, dev_ctx);
    return;
  }

  if (inner_dim == 1) {
    InclusiveScanInnerDim<T, BinaryOp>(
        x, y, outer_dim, mid_dim, init, op, reverse, dev_ctx);
    return;
  }

  // General case: one functor invocation per (outer, inner) pair.
  phi::funcs::ForRange<phi::GPUContext> for_range(dev_ctx,
                                                  outer_dim * inner_dim);
  if (reverse) {
    using ReverseScan =
        InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/true>;
    for_range(ReverseScan(x, y, mid_dim, inner_dim, init, op));
  } else {
    using ForwardScan =
        InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/false>;
    for_range(ForwardScan(x, y, mid_dim, inner_dim, init, op));
  }
}

}  // namespace funcs
}  // namespace phi