// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#include <cub/cub.cuh>  // NOLINT
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace paddle {
namespace operators {

namespace detail {
// Fixed-capacity array usable on both host and device (HOSTDEVICE), intended
// to carry small index/stride tables into kernels by value.
template <typename T, size_t ElementCount>
struct Array {
 public:
  HOSTDEVICE inline Array() {}

  // Unchecked element access.
  HOSTDEVICE inline T& operator[](size_t i) { return data_[i]; }
  HOSTDEVICE inline const T& operator[](size_t i) const { return data_[i]; }

  // Compile-time capacity.
  HOSTDEVICE constexpr inline size_t size() const { return ElementCount; }

  // Builds an Array from any container exposing size() and operator[].
  // The container's size must equal ElementCount (checked with
  // PADDLE_ENFORCE_EQ).
  template <typename VectorLikeType>
  static inline Array<T, ElementCount> From(const VectorLikeType& vec) {
    PADDLE_ENFORCE_EQ(vec.size(), ElementCount, "size not match");
    Array<T, ElementCount> result;
    const size_t count = static_cast<size_t>(vec.size());
    size_t pos = 0;
    while (pos < count) {
      result[pos] = vec[pos];
      ++pos;
    }
    return result;
  }

 private:
  T data_[ElementCount];
};

// Reduces the last axis of a 2-D array stored row by row: y[r] is the
// reduction of x[r * reduce_num + 0 .. r * reduce_num + reduce_num - 1].
//
// Launch layout: one block per row (blockIdx.x selects the row) and BlockDim
// threads per block. Thread t folds columns t, t + BlockDim, ... into a
// private accumulator seeded with `init`. cub::BlockReduce then combines the
// per-thread partials, and thread 0 writes y[blockIdx.x]. cub only guarantees
// the aggregate in thread 0.
//
// NOTE(review): `init` is folded in once per thread, so it must be an
// identity element of `reducer` for the result to be independent of BlockDim.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim>
__global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
                               TransformOp transformer, Ty init,
                               int reduce_num) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
  // 64-bit row offset: blockIdx.x * reduce_num can exceed INT_MAX for large
  // inputs even though each factor fits in an int.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * reduce_num;
  Ty reduce_var = init;
  for (int col = threadIdx.x; col < reduce_num; col += BlockDim) {
    reduce_var =
        reducer(reduce_var, static_cast<Ty>(transformer(x[row_offset + col])));
  }

  // temp_storage is used exactly once here, so no barrier is needed before
  // the collective.
  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

// General reduction kernel for a tensor of rank `Rank`. The `ReduceRank` axes
// listed in `reduce_dim` are reduced, and the remaining axes (`left_dim`) are
// kept.
//
// Launch layout: one block per output element and BlockDim threads per block.
// - blockIdx.x is the flattened index over the kept axes, split into
//   coordinates with `left_strides`.
// - Thread t visits flattened reduce positions t, t + BlockDim, ...
//   (< reduce_num). Each position is split with `reduce_strides`, and the full
//   coordinate is mapped to an element offset with `x_strides`.
// - Per-thread partials are combined with cub::BlockReduce, and thread 0
//   writes y[blockIdx.x].
//
// A thread seeds its partial with its first element, so `init` is not folded
// in on that path. A thread that owns no element contributes `init`. This can
// only happen when BlockDim > reduce_num, and it assumes `init` is an
// identity of `reducer`.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank>
__global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
                             TransformOp transformer, Ty init, int reduce_num,
                             Array<int, Rank> x_strides,
                             Array<int, ReduceRank> reduce_dim,
                             Array<int, ReduceRank> reduce_strides,
                             Array<int, Rank - ReduceRank> left_dim,
                             Array<int, Rank - ReduceRank> left_strides) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;

  // Coordinates along the kept axes are the same for the whole block.
  Array<int, Rank> sub_index;
  int left_idx = blockIdx.x;
  for (int i = 0; i < Rank - ReduceRank; ++i) {
    sub_index[left_dim[i]] = left_idx / left_strides[i];
    left_idx %= left_strides[i];
  }

  const int tid = threadIdx.x;
  Ty reduce_var = init;
  for (int i = tid; i < reduce_num; i += BlockDim) {
    // Fill in the reduced-axis coordinates for flattened reduce position i.
    int reduce_idx = i;
    for (int j = 0; j < ReduceRank; ++j) {
      sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
      reduce_idx %= reduce_strides[j];
    }

    int idx_x = 0;
    for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);

    Ty value = static_cast<Ty>(transformer(x[idx_x]));
    if (i == tid) {
      // First element owned by this thread seeds the partial.
      reduce_var = value;
    } else {
      reduce_var = static_cast<Ty>(reducer(reduce_var, value));
    }
  }

  // temp_storage is used exactly once here, so no barrier is needed before
  // the collective.
  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

// Row-major strides for a tensor with extents `dims`: the last axis has
// stride 1, and every earlier axis steps over the product of all later
// extents. An empty `dims` yields an empty result.
static inline std::vector<int> GetStrides(const std::vector<int>& dims) {
  std::vector<int> result(dims.size());
  if (result.empty()) return result;
  const int last = static_cast<int>(result.size()) - 1;
  result[last] = 1;
  for (int axis = last; axis > 0; --axis) {
    result[axis - 1] = result[axis] * dims[axis];
  }
  return result;
}

// Strides of the sub-tensor formed by axes `idx` (in the given order) of a
// tensor with extents `dims`, as if only those axes were laid out
// contiguously in row-major order.
// Example: dims = {2, 3, 4, 5}, idx = {1, 3} -> {5, 1}.
// An empty `idx` yields an empty result.
static inline std::vector<int> GetStrides(const std::vector<int>& dims,
                                          const std::vector<int>& idx) {
  std::vector<int> result(idx.size());
  if (result.empty()) return result;
  int pos = static_cast<int>(idx.size()) - 1;
  result[pos] = 1;
  while (pos > 0) {
    result[pos - 1] = result[pos] * dims[idx[pos]];
    --pos;
  }
  return result;
}

constexpr int kMaxBlockDim = 512;

// Threads-per-block used to reduce `block_dim` elements: the largest power of
// two that does not exceed `block_dim`, capped at kMaxBlockDim. Inputs below 2
// map to 1.
// Integer arithmetic is used because `1 << (int)std::log2(n)` is undefined for
// n <= 0: log2 yields -inf/NaN, and converting that to int is UB.
static inline int GetDesiredBlockDim(int block_dim) {
  if (block_dim >= kMaxBlockDim) return kMaxBlockDim;
  int desired = 1;
  while (desired <= block_dim / 2) desired <<= 1;
  return desired;
}
// Checks that `reduce_rank` fits a layout whose reduced and kept axes
// alternate:
// - even `rank`: exactly rank / 2 axes are reduced;
// - odd `rank`: either (rank - 1) / 2 or (rank + 1) / 2 axes are reduced.
// Violations are reported through PADDLE_ENFORCE / PADDLE_ENFORCE_EQ.
static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
  const bool rank_is_odd = (rank % 2 != 0);
  if (!rank_is_odd) {
    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2);
    return;
  }
  auto lower_rank = (rank - 1) / 2;
  auto upper_rank = (rank + 1) / 2;
  const bool matches = (reduce_rank == lower_rank) ||
                       (reduce_rank == upper_rank);
  PADDLE_ENFORCE(matches,
                 "When rank = %d, reduce_rank must be %d or %d, but got %d",
                 rank, lower_rank, upper_rank, reduce_rank);
}

/**
 * Launches the reduction of an already-merged tensor layout on `stream`.
 *
 * - rank == reduce_rank (every axis reduced): a single
 *   cub::DeviceReduce::Reduce over all reduce_num transformed elements. Its
 *   scratch memory comes from a temporary framework::Tensor on `place`.
 * - rank == 2 reducing axis 1: ReduceKernel2D, with one block per row.
 * - otherwise: ReduceKernel instantiated for (rank, reduce_rank), with one
 *   block per output element. Ranks 2..9 are dispatched.
 *
 * BlockDim is the threads-per-block of the kernel paths. Failures of the cub
 * calls, an invalid reduce_rank and an unsupported rank are reported through
 * PADDLE_ENFORCE.
 */
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
          typename TransformOp>
static void TensorReduceImpl(
    const Tx* x_data, Ty* y_data, const platform::Place& place,
    const ReduceOp& reducer, const TransformOp& transformer, const Ty& init,
    int left_num, int reduce_num, const std::vector<int>& x_strides,
    const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
    const std::vector<int>& left_dim, const std::vector<int>& left_strides,
    cudaStream_t stream) {
#define CUB_RANK_CASE(i, ...)             \
  case i: {                               \
    constexpr auto kRank = i;             \
    switch (reduce_rank) { __VA_ARGS__; } \
  } break

#define CUB_REDUCE_RANK_CASE(i, ...)                              \
  case i: {                                                       \
    constexpr auto kReduceRank = i;                               \
    ReduceKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,  \
                 kReduceRank><<<left_num, BlockDim, 0, stream>>>( \
        x_data, y_data, reducer, transformer, init, reduce_num,   \
        Array<int, kRank>::From(x_strides),                       \
        Array<int, kReduceRank>::From(reduce_dim),                \
        Array<int, kReduceRank>::From(reduce_strides),            \
        Array<int, kRank - kReduceRank>::From(left_dim),          \
        Array<int, kRank - kReduceRank>::From(left_strides));     \
  } break

  int rank = x_strides.size();
  int reduce_rank = reduce_strides.size();
  if (rank == reduce_rank) {
    // Full reduction. The first DeviceReduce call (null scratch pointer) only
    // queries the scratch size; the second performs the reduction.
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, transformer);
    size_t temp_storage_bytes = 0;
    cudaError_t err = cub::DeviceReduce::Reduce(
        nullptr, temp_storage_bytes, trans_x, y_data, reduce_num, reducer,
        init, stream);
    PADDLE_ENFORCE(err == cudaSuccess,
                   "cub::DeviceReduce::Reduce failed to query temporary "
                   "storage size: %s",
                   cudaGetErrorString(err));
    // NOTE(review): `tmp` is released when this function returns, possibly
    // before the reduction queued on `stream` has run; this relies on the
    // allocator's reuse semantics — confirm.
    framework::Tensor tmp;
    auto* temp_storage = tmp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        place);
    err = cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x,
                                    y_data, reduce_num, reducer, init, stream);
    PADDLE_ENFORCE(err == cudaSuccess, "cub::DeviceReduce::Reduce failed: %s",
                   cudaGetErrorString(err));
    return;
  }
  if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
    ReduceKernel2D<Tx, Ty, ReduceOp, TransformOp,
                   BlockDim><<<left_num, BlockDim, 0, stream>>>(
        x_data, y_data, reducer, transformer, init, reduce_num);
    return;
  }
  /*
  if (rank == 3 && reduce_rank == 1 && reduce_dim[0] == 1) {
    // TODO(liangdun): we can optimize 3d case which the 2nd axis is reduced.
    // Currently, it is handled by code below, but inefficient
    return;
  }
  */

  /**
   * TensorReduce has already merged adjacent dimensions that share the same
   * reduce status, so reduced and non-reduced dimensions must alternate: the
   * pattern is `1010...` or `0101...`, where 1 marks a dimension to be
   * reduced.
   *
   * Therefore:
   * - if rank is odd, only reduce ranks (rank - 1)/2 and (rank + 1)/2 need a
   *   case;
   * - if rank is even, only reduce rank rank/2 needs a case.
   *
   * This cuts the number of instantiated cases from 1+2+3+...+8 = 36 to
   * (1+2)*4 = 12, which speeds up compilation and shrinks the binary.
   */
  CheckReduceRankIsValid(reduce_rank, rank);
  // Without this check an unsupported rank falls through the switch below
  // and silently launches nothing, leaving y uninitialized.
  PADDLE_ENFORCE(rank >= 2 && rank <= 9,
                 "TensorReduce supports ranks in [2, 9] after merging "
                 "adjacent dimensions, but got rank %d",
                 rank);
  switch (rank) {
    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););

    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
  }

#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}

}  // namespace detail

/**
 * Reduces `x` over the axes in `origin_reduce_dims` and writes the result into
 * `y`; the kernels and cub calls are issued on `stream`.
 *
 * Each element is passed through `transformer`, cast to Ty and combined with
 * `reducer`; `init` is forwarded as the reduction's initial value. Negative
 * axes count from the back, so -1 is the last axis. An out-of-range axis is
 * rejected with PADDLE_ENFORCE.
 *
 * Adjacent axes with the same reduced/kept status are merged first, so the
 * kernels only see layouts whose reduced and kept axes alternate. When the
 * merged reduce extent is 1 (including when no axis is reduced), `x` is
 * copied into `y` and `y` keeps its previous dims.
 *
 * NOTE(review): `y` is expected to already carry the output shape; this
 * function only allocates its data — confirm at call sites.
 * NOTE(review): a zero-sized reduce extent is not handled specially here.
 */
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
                  std::vector<int> origin_reduce_dims, const Ty& init,
                  const ReduceOp& reducer, const TransformOp& transformer,
                  cudaStream_t stream) {
  auto x_dim = framework::vectorize<int>(x.dims());
  const int origin_rank = static_cast<int>(x_dim.size());

  // Bitmask of the axes to reduce. Negative axes are normalized before
  // shifting, because shifting by a negative count is undefined behavior.
  int is_reduced = 0;
  for (auto e : origin_reduce_dims) {
    int pos = e >= 0 ? e : e + origin_rank;
    PADDLE_ENFORCE(pos >= 0 && pos < origin_rank,
                   "Reduce dim %d is out of range for a tensor of rank %d", e,
                   origin_rank);
    is_reduced |= 1 << pos;
  }

  // Merge runs of adjacent axes with the same reduced/kept status. Example:
  // dims [2, 3, 4, 5] reducing {1, 2} becomes [2, 12, 5] reducing {1}.
  std::vector<int> new_x_dim, new_reduce_dims;
  for (int i = 0; i < origin_rank; ++i) {
    if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
      new_x_dim.push_back(x_dim[i]);
      if ((is_reduced >> i) & 1) {
        new_reduce_dims.push_back(static_cast<int>(new_x_dim.size()) - 1);
      }
    } else {
      new_x_dim.back() *= x_dim[i];
    }
  }
  x_dim = new_x_dim;
  origin_reduce_dims = new_reduce_dims;
  int x_rank = static_cast<int>(x_dim.size());

  std::set<int> left_set, reduce_set;
  for (int i = 0; i < x_rank; ++i) left_set.insert(i);
  for (auto e : origin_reduce_dims) {
    left_set.erase(e);
    reduce_set.insert(e);
  }

  std::vector<int> reduce_dim(reduce_set.begin(), reduce_set.end());
  std::vector<int> left_dim(left_set.begin(), left_set.end());

  std::vector<int> x_strides = detail::GetStrides(x_dim);
  std::vector<int> reduce_strides = detail::GetStrides(x_dim, reduce_dim);
  std::vector<int> left_strides = detail::GetStrides(x_dim, left_dim);

  // reduce_num: elements folded into each output element.
  // left_num: number of output elements.
  // An empty axis list contributes an extent of 1; indexing the stride vectors
  // unconditionally would be out of bounds.
  int reduce_num = 1;
  if (!reduce_dim.empty()) {
    reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
  }
  int left_num = 1;
  if (!left_dim.empty()) {
    left_num = left_strides[0] * x_dim[left_dim[0]];
  }

  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());

  // Nothing to combine: the output is x itself, reshaped to y's dims.
  if (reduce_num == 1) {
    auto out_dims = y->dims();
    framework::TensorCopy(x, y->place(), y);
    y->Resize(out_dims);
    return;
  }

#define CUB_BLOCK_DIM_CASE(block_dim)                                    \
  case block_dim: {                                                      \
    constexpr auto kBlockDim = block_dim;                                \
    detail::TensorReduceImpl<Tx, Ty, kBlockDim, ReduceOp, TransformOp>(  \
        x_data, y_data, x.place(), reducer, transformer, init, left_num, \
        reduce_num, x_strides, reduce_dim, reduce_strides, left_dim,     \
        left_strides, stream);                                           \
  } break

  switch (detail::GetDesiredBlockDim(reduce_num)) {
    CUB_BLOCK_DIM_CASE(512);
    CUB_BLOCK_DIM_CASE(256);
    CUB_BLOCK_DIM_CASE(128);
    CUB_BLOCK_DIM_CASE(64);
    CUB_BLOCK_DIM_CASE(32);
    CUB_BLOCK_DIM_CASE(16);
    CUB_BLOCK_DIM_CASE(8);
    CUB_BLOCK_DIM_CASE(4);
    CUB_BLOCK_DIM_CASE(2);
  }
#undef CUB_BLOCK_DIM_CASE
}

/**
 * Deferred call to TensorReduce whose output type Ty is chosen later:
 * apply<Ty>() casts the stored `init` to Ty and forwards every stored
 * argument to TensorReduce<Tx, Ty, ReduceOp, TransformOp>.
 * NOTE(review): the `template <typename Ty> void apply() const` shape
 * presumably serves a data-type visitor (e.g. framework::VisitDataType) —
 * confirm at call sites.
 *
 * `x` is held by reference and must outlive the functor. `init`, `reducer`
 * and `transformer` are stored by value: the constructor arguments may be
 * temporaries (e.g. `0.0`, `cub::Sum()`), and reference members would dangle
 * once the constructing expression ends.
 */
template <typename Tx, typename ReduceOp, typename TransformOp>
struct TensorReduceFunctor {
  const framework::Tensor& x;
  framework::Tensor* y;
  std::vector<int> origin_reduce_dims;
  double init;
  ReduceOp reducer;
  TransformOp transformer;
  cudaStream_t stream;
  TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y,
                      std::vector<int> origin_reduce_dims, const double& init,
                      const ReduceOp& reducer, const TransformOp& transformer,
                      cudaStream_t stream)
      : x(x),
        y(y),
        origin_reduce_dims(origin_reduce_dims),
        init(init),
        reducer(reducer),
        transformer(transformer),
        stream(stream) {}

  template <typename Ty>
  void apply() const {
    const Ty init_cast = static_cast<Ty>(init);
    TensorReduce<Tx, Ty, ReduceOp, TransformOp>(
        x, y, origin_reduce_dims, init_cast, reducer, transformer, stream);
  }
};
}  // namespace operators
}  // namespace paddle