// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/cuda_device_function.h"

// Boundary that decides whether to split the reduction,
// i.e. whether ReduceHigherDim is used.
#define REDUCE_SPLIT_BOUNDARY 512

namespace paddle {
namespace operators {
namespace detail {

// Post processing function for sum, max, min, prod, any
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x);
  }
};

// Post processing function for mean
template <typename T>
struct DivideFunctor {
  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}

  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }

 private:
  T n_inv;
};

static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}
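
// Illustrative examples (editorial note, not part of the original comments):
// GetLastPow2 rounds its argument down to the nearest power of two, e.g.
// GetLastPow2(1) == 1, GetLastPow2(6) == 4, GetLastPow2(20) == 16 and
// GetLastPow2(64) == 64.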

// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
                                             const std::vector<int>& idx) {
  int n = static_cast<int>(idx.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[idx[i + 1]];
  }
  return strides;
}
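
// Illustrative example (editorial note): for dims = {8, 6} the call
// GetDimStrides(dims, {0, 1}) returns {6, 1}, while GetDimStrides(dims, {1})
// returns {1}; an empty idx yields an empty stride vector.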

#ifdef __HIPCC__
constexpr int kMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

// get blockDim for reduceLastDim and reduceAny
static inline int GetBlockDim(int block_dim) {
  return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim);
}
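
// Illustrative example (editorial note): on CUDA, where kMaxThread is 128,
// GetBlockDim(200) returns 128 and GetBlockDim(48) returns 32, so the launch
// block size is always a power of two no larger than kMaxThread.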

// check whether the reduce rank is valid
static inline void CheckReduceRank(int reduce_rank, int rank) {
  if (rank % 2 == 0) {
    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                      platform::errors::InvalidArgument(
                          "ReduceOp: invalid reduce rank. When rank = %d, "
                          "reduce_rank must be %d, but got %d.",
                          rank, rank / 2, reduce_rank));
  } else {
    auto lower_rank = (rank - 1) / 2;
    auto upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE_EQ(
        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
        platform::errors::InvalidArgument(
            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
            "must be %d or %d, but got %d.",
            rank, lower_rank, upper_rank, reduce_rank));
  }
}
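
// Illustrative example (editorial note): for rank = 4 the only valid
// reduce_rank is 2; for rank = 5 both 2 and 3 are accepted. Any other value
// triggers the InvalidArgument error above.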

// convert dims from vector to array
template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> VectorToArray(
    const VectorLikeType& vec) {
  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
                    platform::errors::InvalidArgument(
                        "Cub reduce Array: size not match. Received "
                        "vec.size() %d !=  ElementCount %d.",
                        vec.size(), ElementCount));
  size_t n = static_cast<size_t>(vec.size());
  paddle::framework::Array<T, ElementCount> ret;
  for (size_t i = 0; i < n; ++i) {
    ret[i] = vec[i];
  }
  return ret;
}

}  // namespace detail

using Tensor = framework::Tensor;

enum ReduceType {
  kReduceAll = 0x00,        // when reduce_rank == x_rank
  kReduceLastDim = 0x01,    // when reduce_dim[0] == x_dim.size() - 1;
  kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
  kReduceAny = 0x03,        // when reduce_dim.size() > 1
};

// reduce config
template <typename Ty>
struct ReduceConfig {
  ReduceConfig(const std::vector<int>& origin_reduce_dims,
               const std::vector<int>& origin_x_dim)
      : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}

  // get the parameters of reduceKernel
  void Run() {
    // step1: update the reduce_dim left_dim and x_dim
    SetReduceDim();

    // step2: get the strides of dim for reduceAny and reduceLastDim
    SetStrides();

    // step3: get the type of reduce
    SetReduceType();

    // step4: set the block and grid for launch kernel
    SetBlockDim();
  }

  // when should_reduce_again is true, we need to allocate temporary space to
  // hold the intermediate result
  void SetOutputData(Ty* y_data, const platform::Place& place,
                     framework::Tensor* tmp) {
    if (should_reduce_again) {
      output_data = tmp->mutable_data<Ty>(
          framework::make_ddim(
              {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}),
          place);
    } else {
      output_data = y_data;
    }
  }

 private:
  // set reduce_dim, left_dim and update x_dim
  // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1]
  //     --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
  void SetReduceDim() {
    std::set<int> reduce_set;
    for (auto e : reduce_dims_origin) {
      auto pos = e >= 0 ? e : e + x_dim.size();
      reduce_set.insert(pos);
    }

    std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
    std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());

    // update reduce_dim and x_dim
    std::vector<int> x_new_dim;

    reduce_dim.push_back(reduce_dim_temp[0]);
    x_new_dim.push_back(x_dim[0]);

    int idx_reduce = 1;
    int num = 0;

    if (reduce_dim_temp.size() > 1) {
      for (int i = 1; i < x_dim.size(); i++) {
        if ((idx_reduce < reduce_dim_temp.size()) &&
            (i == reduce_dim_temp[idx_reduce])) {
          int result =
              reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
          bool is_equal = ((result - num) == 1);
          if (is_equal) {
            x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
            num++;
          } else {
            reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
            x_new_dim.push_back(x_dim[i]);
          }
          idx_reduce++;
        } else {
          x_new_dim.push_back(x_dim[i]);
        }
      }
    } else {
      x_new_dim = x_dim;
    }

    // update x_dim
    x_dim = x_new_dim;
    std::vector<int>().swap(x_new_dim);

    std::vector<int> reduce_dim_new;
    int is_reduced = 0;
    for (auto e : reduce_dim) {
      is_reduced |= 1 << e;
    }

    std::vector<int>().swap(reduce_dim);

    for (int i = 0; i < x_dim.size(); i++) {
      if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
        x_new_dim.push_back(x_dim[i]);
        if ((is_reduced >> i) & 1)
          reduce_dim_new.push_back(x_new_dim.size() - 1);
      } else {
        x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
      }
    }

    x_dim = x_new_dim;
    reduce_dim = reduce_dim_new;

    int x_rank = static_cast<int>(x_dim.size());
    std::set<int> left_set;

    for (int i = 0; i < x_rank; ++i) {
      left_set.insert(i);
    }

    for (auto e : reduce_dim) {
      left_set.erase(e);
    }

    left_dim.assign(left_set.begin(), left_set.end());
  }
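
  // Illustrative example (editorial note): with x_dim = [2, 3, 4, 5] and
  // origin_reduce_dims = [1, 2], the two contiguous reduce axes are merged so
  // that x_dim becomes [2, 12, 5], reduce_dim becomes [1] and left_dim
  // becomes [0, 2].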

  // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
  // eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  //     --SetStrides--> x_strides= [6,1], reduce_strides = [1],
  //     left_strides = [1]
  void SetStrides() {
    std::vector<int> idx_dim;
    for (int i = 0; i < x_dim.size(); i++) {
      idx_dim.push_back(i);
    }

    x_strides = detail::GetDimStrides(x_dim, idx_dim);
    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
    left_strides = detail::GetDimStrides(x_dim, left_dim);
    reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

    left_num = 1;
    if (left_dim.size()) {
      left_num = left_strides[0] * x_dim[left_dim[0]];
    }
  }

  // get the reduceType
  // eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim -->reduceFirstDim
  //     x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim
  //     x_dim = [8] reduce_dim = [0] --> reduceAll
  //     x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny
  void SetReduceType() {
    int rank = x_dim.size();
    int reduce_rank = reduce_dim.size();
    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                           (left_num > REDUCE_SPLIT_BOUNDARY);

    if (rank == reduce_rank) {
      reduce_type = static_cast<int>(ReduceType::kReduceAll);

    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);

    } else if (reduce_rank == 1 &&
               ((rank == 2 && is_large_enough) || rank != 2)) {
      // ReduceFirstDim and reduceSecondDim
      reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);

    } else {
      reduce_type = static_cast<int>(ReduceType::kReduceAny);
    }
  }

  // set block and grid for launch kernel
  // for ReduceHigherDim: if there are enough blocks, split reduce_num;
  //                      otherwise use block(32, 1) and grid(block_num, 1)
  // for others: block(block_num, 1), grid(left_num, 1)
  void SetBlockDim() {
    // init
    int block_num = detail::GetBlockDim(reduce_num);
    should_reduce_again = false;

    dim3 block_dim(block_num, 1);
    dim3 grid_dim(left_num, 1);
    blocking_size = reduce_num;

    if (reduce_type == ReduceType::kReduceHigherDim) {
      int last_dim_num = x_dim.back();
      // update left_num
      int grid_z = left_num / last_dim_num;
      left_num = last_dim_num;

      block_dim.z = 1;
      grid_dim.z = grid_z;

      int device_id = platform::GetCurrentDeviceId();
      int max_mp = platform::GetCUDAMultiProcessors(device_id);
      int max_threads_per_mp =
          platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
      int max_threads = max_threads_per_mp * max_mp;

      // init
      int num_block = (max_threads / left_num);

      if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
        blocking_size = detail::GetLastPow2(reduce_num / num_block);

        if (blocking_size <= 1) {
          blocking_size = detail::GetLastPow2(sqrt(reduce_num));
        } else if (blocking_size * 2 < reduce_num) {
          blocking_size *= 2;
        }

        should_reduce_again = true;

        block_dim.x = 32;
        block_dim.y = 1;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;

      } else {
        block_dim.x = 32;
        block_dim.y = 1;
        blocking_size = reduce_num;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = 1;
      }
    }

    block = block_dim;
    grid = grid_dim;
  }
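
  // Illustrative example (editorial note; the device figures are an assumed
  // configuration): take x_dim = [2048, 64] with reduce_dim = [0], so
  // reduce_num = 2048 and left_num = 64, and assume max_threads = 80 * 2048.
  // Then num_block = max_threads / 64 is large and reduce_num exceeds
  // REDUCE_SPLIT_BOUNDARY, so the reduction is split: blocking_size becomes
  // GetLastPow2(sqrt(2048)) = 32, should_reduce_again is set, block = (32, 1)
  // and grid = (2, 64, 1), leaving 64 partial results per output element for
  // the second pass.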

 public:
  std::vector<int> reduce_dims_origin;
  std::vector<int> reduce_dim;
  std::vector<int> x_dim;
  std::vector<int> left_dim;
  std::vector<int> x_strides;
  std::vector<int> left_strides;
  std::vector<int> reduce_strides;

  int reduce_type;
  int reduce_num;
  int left_num;
  int blocking_size;
  bool should_reduce_again;

  Ty* output_data;

  dim3 block;
  dim3 grid;
};

template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}
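
// Illustrative example (editorial note): with a sum reducer and each of the 32
// lanes holding the value 1, WarpReduce leaves 32 in lane 0 after
// log2(kWarpSize) shuffle-down steps; only lane 0 is guaranteed to hold the
// full result.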

/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory
 *   \    \    /     /        ---->3. Load the result above from shared memory
 *        res                         to warp0 and process the second WarpReduce
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
  using detail::kWarpSize;
  __shared__ T shared[kWarpSize];
  int block_dim_x = blockDim.x;
  if (blockDim.x > kWarpSize) {
    block_dim_x = blockDim.x / kWarpSize;
    int lane = threadIdx.x % kWarpSize;
    int wid = threadIdx.x / kWarpSize;
    val = WarpReduce(val, reducer);
    if (lane == 0) {
      shared[wid] = val;
    }
    __syncthreads();
    val = shared[lane];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
// function will be used
// blockId.x -> left_num, threadId.x -> reduce_num
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                              ReduceOp reducer,
                                              TransformOp transformer, Ty init,
                                              int reduce_num) {
  int idx_x = blockIdx.x * reduce_num;
  int idx_y = threadIdx.x;
  Ty reduce_var = init;
  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
    reduce_var =
        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
  }
  __syncthreads();

  reduce_var = BlockReduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}
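
// Illustrative example (editorial note): for x_dim = [8, 1024] with
// reduce_dim = [1], each of the 8 blocks strides over its own row of 1024
// elements (blockDim.x threads apart), then BlockReduce combines the
// per-thread partials and thread 0 writes y[blockIdx.x].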

// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
//     if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
//     else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
                                                ReduceOp reducer,
                                                TransformOp transformer,
                                                Ty init, int reduce_num,
                                                int left_num, int block_size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int idy = blockIdx.y * block_size;

  Ty reduce_var = init;

  if (idx < left_num) {
    int loop = reduce_num - idy;
    loop = loop > block_size ? block_size : loop;

    for (int iy = 0; iy < loop; iy++) {
      int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
    }

    y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
        reduce_var;
  }
}
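
// Illustrative example (editorial note): for x_dim = [ny, nx] with
// reduce_dim = [0], the thread at (blockIdx.x * blockDim.x + threadIdx.x)
// owns one column, accumulates up to block_size consecutive rows starting at
// blockIdx.y * block_size, and writes one partial result per (column, row
// chunk); when should_reduce_again is set, a second launch folds the gridDim.y
// partials into the final output.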

// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
// blockId.x -> left_num, threadId.x -> reduce_num
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceAny(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
    int reduce_num, paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  int sub_index[Rank];
  int left_idx = blockIdx.x;
  for (int i = 0; i < Rank - ReduceRank; ++i) {
    sub_index[left_dim[i]] = left_idx / left_strides[i];
    left_idx %= left_strides[i];
  }

  int reduce_idx = threadIdx.x;
  for (int j = 0; j < ReduceRank; ++j) {
    sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
    reduce_idx %= reduce_strides[j];
  }

  int idx_x = 0;
  for (int k = 0; k < Rank; ++k) {
    idx_x += (sub_index[k] * x_strides[k]);
  }
  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));

  for (int i = threadIdx.x + blockDim.x; i < reduce_num; i += blockDim.x) {
    int reduce_idx = i;

    for (int j = 0; j < ReduceRank; ++j) {
      sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
      reduce_idx %= reduce_strides[j];
    }

    int idx_x = 0;
    for (int k = 0; k < Rank; ++k) {
      idx_x += (sub_index[k] * x_strides[k]);
    }

    reduce_var = static_cast<Ty>(
        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
  }
  __syncthreads();

  reduce_var = BlockReduce(reduce_var, reducer);
  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}
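
// Illustrative example (editorial note): for x_dim = [4, 3, 2] with
// reduce_dim = [0, 2] and left_dim = [1], the strides are x_strides = [6, 2, 1],
// reduce_strides = [2, 1] and left_strides = [1], so reduce_num = 8 and
// left_num = 3. Each block owns one value of the middle axis; each linear
// reduce index r in [0, 8) is decoded as sub_index[0] = r / 2 and
// sub_index[2] = r % 2 before being folded back into idx_x via x_strides.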

// module function designed for global function
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceModule(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int blocking_size, int reduce_type,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
  if (reduce_type == ReduceType::kReduceLastDim) {
    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
                                                 init, reduce_num);

    // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
  } else if (reduce_type == ReduceType::kReduceHigherDim) {
    ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);

    // reduce_rank >= 2
  } else {
    ReduceAny<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
        x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
        reduce_strides, left_dim, left_strides);
  }
}

template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int Rank, int ReduceRank>
__global__ void ReduceKernelFunction(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int block_size, int reduce_type,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  ReduceModule<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
      x, y, reducer, transformer, init, reduce_num, left_num, block_size,
      reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
      left_strides);
}

template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                               const ReduceOp& reducer, Ty init,
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  using TransformOp = typename ReduceOp::Transformer;

  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, Rank,
                       ReduceRank><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
      config.reduce_type, detail::VectorToArray<int, Rank>(config.x_strides),
      detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
      detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
      detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
      detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));

  if (config.should_reduce_again) {
    dim3 block(config.block.x, 1, 1);
    dim3 grid(config.grid.x, 1, config.grid.z);

    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank,
                         ReduceRank><<<grid, block, 0, stream>>>(
        config.output_data, y_data, reducer,
        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
        detail::VectorToArray<int, Rank>(config.x_strides),
        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
        detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
        detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
        detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
  }
}
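
// Note (editorial): when config.should_reduce_again is set, the first launch
// writes grid.y partial results per output element into config.output_data;
// the second launch then runs in kReduceHigherDim mode with an
// IdentityFunctor transform to fold those partials into y_data.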

template <typename Tx, typename Ty, typename ReduceOp>
static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
                             const ReduceOp& reducer, Ty init,
                             gpuStream_t stream, ReduceConfig<Ty> config) {
  int reduce_rank = config.reduce_strides.size();
  int rank = config.x_strides.size();

#define CUB_RANK_CASE(i, ...)             \
  case i: {                               \
    constexpr auto Rank = i;              \
    switch (reduce_rank) { __VA_ARGS__; } \
  } break

#define CUB_REDUCE_RANK_CASE(i, ...)                        \
  case i: {                                                 \
    constexpr auto ReduceRank = i;                          \
    LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>( \
        x_data, y_data, reducer, init, stream, config);     \
  } break

  detail::CheckReduceRank(reduce_rank, rank);
  switch (rank) {
    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););

    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
  }

#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}

template <typename Tx, typename Ty,
          template <typename, typename> class ReduceOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                             std::vector<int> origin_reduce_dims,
                             gpuStream_t stream) {
  auto x_dim = framework::vectorize<int>(x.dims());
  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
  config.Run();  // get the parameters of LaunchReduceKernel

  // After config.Run(), SetOutputData decides where the result is written:
  // when should_reduce_again is true, the intermediate result is stored in
  // the temporary tensor's buffer; otherwise it goes directly to y_data.
  framework::Tensor tmp;
  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());

  if (config.reduce_num == 1) {
    auto out_dims = y->dims();
    framework::TensorCopy(x, y->place(), y);
    y->Resize(out_dims);
    return;
  }

  config.SetOutputData(y_data, x.place(), &tmp);

  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
  auto reducer = ReduceOp<Tx, Ty>();
  // launch CUB::Reduce
  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, TransformOp(config.reduce_num));
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);
    framework::Tensor tmp;
    auto* temp_storage = tmp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        x.place());
    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);

    return;
  }

  ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>>(x_data, y_data, reducer,
                                             reducer.initial(), stream, config);
}
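
// Usage sketch (editorial note; CustomSum is a hypothetical reduce functor
// that provides the nested Transformer type and initial() value required
// above):
//
//   std::vector<int> reduce_dims = {0, 2};
//   TensorReduceFunctorImpl<float, float, CustomSum>(x, &y, reduce_dims,
//                                                    dev_ctx.stream());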

template <typename Tx, template <typename, typename> class ReduceOp>
struct TensorReduceFunc {
  const framework::Tensor& x;
  framework::Tensor* y;
  std::vector<int> origin_reduce_dims;
  gpuStream_t stream;
  TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
                   std::vector<int> origin_reduce_dims, gpuStream_t stream)
      : x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}

  template <typename Ty>
  void apply() const {
    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
  }
};
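
// Usage sketch (editorial note; assumes framework::VisitDataType dispatches on
// the output dtype and calls apply<Ty>() on the visitor, and CustomSum is the
// same hypothetical reduce functor as above):
//
//   framework::VisitDataType(
//       static_cast<framework::proto::VarType::Type>(out_dtype),
//       TensorReduceFunc<float, CustomSum>(x, &y, reduce_dims, stream));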

}  // namespace operators
}  // namespace paddle