// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

// Boundary for deciding whether to split the reduce, i.e. whether to use
// ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512
#define REDUCE_VEC_SIZE 4

namespace paddle {
namespace operators {
namespace detail {

// Post processing function for sum, max, min, prod, any
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x);
  }
};

// Post processing function for mean
template <typename T>
struct DivideFunctor {
  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}

  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }

 private:
  T n_inv;
};

static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}
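
// Example: GetLastPow2(20) == 16 (20 -> 0b10100, bit-smearing yields
// 0b11111 = 31, and 31 - (31 >> 1) = 16); GetLastPow2(128) == 128,
// GetLastPow2(1) == 1. The result is the largest power of two that does not
// exceed the argument (with a lower bound of 1).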

static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }

// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
                                             const std::vector<int>& idx) {
  int n = static_cast<int>(idx.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[idx[i + 1]];
  }
  return strides;
}
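
// Example: GetDimStrides({8, 6}, {0, 1}) == {6, 1};
// GetDimStrides({2, 12, 5}, {0, 2}) == {5, 1}, i.e. strides are built over
// the selected dims only, with the innermost selected dim getting stride 1.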

#ifdef __HIPCC__
constexpr int kMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

// get blockDim for reduceLastDim and reduceAny
static inline int GetBlockDim(int block_dim) {
  return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim);
}
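
// Example (CUDA path, where kMaxThread = 128): GetBlockDim(100) == 64,
// GetBlockDim(128) == 128, GetBlockDim(1000) == 128.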

// check whether the reduce rank is valid
static inline void CheckReduceRank(int reduce_rank, int rank) {
  if (rank % 2 == 0) {
    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                      platform::errors::InvalidArgument(
                          "ReduceOp: invalid reduce rank. When rank = %d, "
                          "reduce_rank must be %d, but got %d.",
                          rank, rank / 2, reduce_rank));
  } else {
    auto lower_rank = (rank - 1) / 2;
    auto upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE_EQ(
        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
        platform::errors::InvalidArgument(
            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
            "must be %d or %d, but got %d.",
            rank, lower_rank, upper_rank, reduce_rank));
  }
}
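
// Example: CheckReduceRank(2, 4) and CheckReduceRank(3, 5) pass, while
// CheckReduceRank(1, 4) throws InvalidArgument, since an even rank of 4
// requires reduce_rank == 2.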

// convert dims from vector to array
template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> VectorToArray(
    const VectorLikeType& vec) {
  PADDLE_ENFORCE_LE(vec.size(), ElementCount,
                    platform::errors::InvalidArgument(
                        "Cub reduce Array: size does not match. Received "
                        "vec.size() %d > ElementCount %d.",
                        vec.size(), ElementCount));
  size_t n = static_cast<size_t>(vec.size());
  paddle::framework::Array<T, ElementCount> ret;
  for (size_t i = 0; i < n; ++i) {
    ret[i] = vec[i];
  }
  return ret;
}

}  // namespace detail

using Tensor = framework::Tensor;
constexpr int kMaxRank = framework::DDim::kMaxRank;

enum ReduceType {
  kReduceAll = 0x00,        // when reduce_rank == x_rank
  kReduceLastDim = 0x01,    // when reduce_dim[0] == x_dim.size() - 1;
  kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
  kReduceAny = 0x03,        // when reduce_dim.size() > 1
};

struct IndexCalculator {
  IndexCalculator(int dim, const std::vector<int>& cal_dims,
                  const std::vector<int>& cal_strides,
                  const std::vector<int>& full_strides)
      : dim(dim) {
    dims = detail::VectorToArray<int, kMaxRank>(cal_dims);
    strides = detail::VectorToArray<int, kMaxRank>(full_strides);
    std::vector<platform::FastDivMod> cal_divmoders;
    // fast divmod
    for (auto i : cal_strides) {
      cal_divmoders.push_back(platform::FastDivMod(i));
    }
    divmoders =
        detail::VectorToArray<platform::FastDivMod, kMaxRank>(cal_divmoders);
  }

  __device__ inline int Get(int offset) const {
    int index = 0;
#pragma unroll
    for (int i = 0; i < kMaxRank; ++i) {
      if (i == dim) {
        break;
      }
      auto divmod = divmoders[i].Divmod(offset);
      index += (divmod.val[0] * strides[dims[i]]);
      offset = divmod.val[1];
    }
    return index;
  }

  int dim;
  framework::Array<int, kMaxRank> dims;
  framework::Array<int, kMaxRank> strides;
  framework::Array<platform::FastDivMod, kMaxRank> divmoders;
};
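
// Example: for the merged shape x_dim = {2, 12, 5} (x_strides = {60, 5, 1})
// with left_dim = {0, 2} and left_strides = {5, 1}, the left-index calculator
// IndexCalculator(2, {0, 2}, {5, 1}, {60, 5, 1}) maps the compact left offset
// 7 (i.e. left coordinates (1, 2)) to the x offset 1 * 60 + 2 * 1 = 62.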

// reduce config
template <typename Ty>
struct ReduceConfig {
  ReduceConfig(const std::vector<int>& origin_reduce_dims,
               const std::vector<int>& origin_x_dim)
      : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}

  // get the parameters of reduceKernel
  void Run() {
    // step1: update the reduce_dim left_dim and x_dim
    SetReduceDim();

    // step2: get the strides of dim for reduceAny and reduceLastDim
    SetStrides();

    // step3: get the type of reduce
    SetReduceType();

    // step4: set the block and grid for launch kernel
    SetBlockDim();
  }

  // when should_reduce_again is true, we need to allocate temporary space for
  // the intermediate data
  void SetOutputData(Ty* y_data, const platform::Place& place,
                     framework::Tensor* tmp) {
    if (should_reduce_again) {
      output_data = tmp->mutable_data<Ty>(
          framework::make_ddim(
              {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}),
          place);
    } else {
      output_data = y_data;
    }
  }

 private:
  // set reduce_dim, left_dim and update x_dim
  // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1]
  //     --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
  void SetReduceDim() {
    std::set<int> reduce_set;
    for (auto e : reduce_dims_origin) {
      auto pos = e >= 0 ? e : e + x_dim.size();
      reduce_set.insert(pos);
    }

    std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
    std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());

    // update reduce_dim and x_dim
    std::vector<int> x_new_dim;

    reduce_dim.push_back(reduce_dim_temp[0]);
    x_new_dim.push_back(x_dim[0]);

    int idx_reduce = 1;
    int num = 0;

    if (reduce_dim_temp.size() > 1) {
      for (int i = 1; i < x_dim.size(); i++) {
        if ((idx_reduce < reduce_dim_temp.size()) &&
            (i == reduce_dim_temp[idx_reduce])) {
          int result =
              reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
          bool is_equal = ((result - num) == 1);
          if (is_equal) {
            x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
            num++;
          } else {
            reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
            x_new_dim.push_back(x_dim[i]);
          }
          idx_reduce++;
        } else {
          x_new_dim.push_back(x_dim[i]);
        }
      }
    } else {
      x_new_dim = x_dim;
    }

    // update x_dim
    x_dim = x_new_dim;
    std::vector<int>().swap(x_new_dim);

    std::vector<int> reduce_dim_new;
    int is_reduced = 0;
    for (auto e : reduce_dim) {
      is_reduced |= 1 << e;
    }

    std::vector<int>().swap(reduce_dim);

    for (int i = 0; i < x_dim.size(); i++) {
      if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
        x_new_dim.push_back(x_dim[i]);
        if ((is_reduced >> i) & 1)
          reduce_dim_new.push_back(x_new_dim.size() - 1);
      } else {
        x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
      }
    }

    x_dim = x_new_dim;
    reduce_dim = reduce_dim_new;

    int x_rank = static_cast<int>(x_dim.size());
    std::set<int> left_set;

    for (int i = 0; i < x_rank; ++i) {
      left_set.insert(i);
    }

    for (auto e : reduce_dim) {
      left_set.erase(e);
    }

    left_dim.assign(left_set.begin(), left_set.end());

    // whether the last dim is involved in the reduction
    reduce_lastdim = (reduce_dim.back() == x_dim.size() - 1);
  }
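
  // Example trace of SetReduceDim: x_dim = [2, 3, 4, 5],
  // origin_reduce_dims = [1, 2]. The adjacent reduce dims are merged, giving
  // x_dim = [2, 12, 5], reduce_dim = [1], left_dim = [0, 2] and
  // reduce_lastdim = false.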

  // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
  // eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  //     --SetStrides--> x_strides = [6, 1], reduce_strides = [1],
  //     left_strides = [1]
  void SetStrides() {
    std::vector<int> idx_dim;
    for (int i = 0; i < x_dim.size(); i++) {
      idx_dim.push_back(i);
    }

    x_strides = detail::GetDimStrides(x_dim, idx_dim);
    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
    left_strides = detail::GetDimStrides(x_dim, left_dim);
    reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

    left_num = 1;
    if (left_dim.size()) {
      left_num = left_strides[0] * x_dim[left_dim[0]];
    }
  }
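
  // Example trace of SetStrides (continuing the example above):
  // x_dim = [2, 12, 5], reduce_dim = [1], left_dim = [0, 2] gives
  // x_strides = [60, 5, 1], reduce_strides = [1], left_strides = [5, 1],
  // reduce_num = 1 * 12 = 12 and left_num = 5 * 2 = 10.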

  // get the reduceType
  // eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim -->reduceFirstDim
  //     x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim
  //     x_dim = [8] reduce_dim = [0] --> reduceAll
  //     x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny
  void SetReduceType() {
    int rank = x_dim.size();
    int reduce_rank = reduce_dim.size();
    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                           (left_num > REDUCE_SPLIT_BOUNDARY);

    if (rank == reduce_rank) {
      reduce_type = static_cast<int>(ReduceType::kReduceAll);
    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
    } else if (reduce_rank == 1 &&
               ((rank == 2 && is_large_enough) || rank != 2)) {
      // ReduceFirstDim and reduceSecondDim
      reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
    } else {
      reduce_type = static_cast<int>(ReduceType::kReduceAny);
    }
  }

  void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
    constexpr int min_reduce_num_per_thread = 16;
    constexpr int max_reduce_num_per_thread = 256;
    constexpr int max_num_threads = detail::kMaxThread;

    // set block size.
    // 1. If reduce_lastdim == true, all threads with the same threadIdx.y
    //    cooperate on the reduction for one output.
    //    The number of outputs per block is blockDim.y;
    // 2. If reduce_lastdim == false, each threadIdx.x processes a different
    //    reduction and produces its output independently. If necessary, a
    //    further reduction is performed along block y.
    //    The number of outputs per block is blockDim.x;
    int block_x, block_y;
    int grid_num, reduce_num_per_thread;
    if (reduce_lastdim) {
      block_x = detail::GetBlockDim(reduce_num);
      block_y = detail::GetBlockDim(left_num);
      block_dim->x = block_x;
      block_dim->y =
          std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
      grid_num = detail::AlignUp(left_num, block_dim->y);
      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
    } else {
      block_x = detail::GetBlockDim(left_num);
      block_y = detail::GetBlockDim(reduce_num);
      block_dim->x = std::min(block_x, 32);
      block_dim->y =
          std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
      block_dim->x =
          std::min(block_x, static_cast<int>(max_num_threads / block_dim->y));
      grid_num = detail::AlignUp(left_num, block_dim->x);
      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->y);
    }
    int device_id = platform::GetCurrentDeviceId();
    int max_mp = platform::GetCUDAMultiProcessors(device_id);
    int max_threads_per_mp =
        platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
    int max_threads = max_threads_per_mp * max_mp;
    int num_threads = block_dim->x * block_dim->y;
    int max_num_blocks = max_threads / num_threads;

    // set grid size.
    // Whether grid.y is set larger than 1 is decided by the following rules:
    // 1. The number of elements each thread processes should be no less than
    //    min_reduce_num_per_thread but no more than max_reduce_num_per_thread;
    // 2. It should maximize the utilization of the SMs.
    // So we choose the minimum of input_split_num_1 and input_split_num_3 to
    // make each thread process as much data as possible. Meanwhile, the
    // per-thread workload cannot exceed max_reduce_num_per_thread, so we
    // choose the maximum of that result and input_split_num_2.
    int input_split_num_1 =
        detail::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
    int input_split_num_2 =
        detail::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
    int input_split_num_3 = detail::AlignUp(max_num_blocks, grid_num);

    grid_dim->x = grid_num;
    grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
                           input_split_num_2);
    // if grid.y > 1, we need to launch the reduce kernel again.
    if (grid_dim->y > 1) {
      should_reduce_again = true;
    }
  }
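
  // Illustrative numbers (CUDA path with kMaxThread = 128; grid.y also
  // depends on the device through max_num_blocks): with reduce_lastdim ==
  // true, reduce_num = 4096 and left_num = 8, we get block = (128, 1),
  // grid.x = 8 and reduce_num_per_thread = 32, so
  //   grid.y = max(min(AlignUp(32, 16), input_split_num_3), AlignUp(32, 256))
  //          = max(min(2, input_split_num_3), 1),
  // and should_reduce_again becomes true whenever grid.y > 1.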

  // set block and grid for launching the kernel
  // for ReduceHigherDim: if there are enough blocks -> split reduce_num
  //                      else init block(32, 1) grid(block_num, 1)
  // for others: block(block_num, 1), grid(left_num, 1)
  void SetBlockDim() {
    // init
    int block_num = detail::GetBlockDim(reduce_num);
    should_reduce_again = false;

    dim3 block_dim(block_num, 1);
    dim3 grid_dim(left_num, 1);
    blocking_size = reduce_num;

    if (reduce_type == ReduceType::kReduceHigherDim) {
      int last_dim_num = x_dim.back();
      // update left_num
      int grid_z = left_num / last_dim_num;
      left_num = last_dim_num;

      block_dim.z = 1;
      grid_dim.z = grid_z;

      int device_id = platform::GetCurrentDeviceId();
      int max_mp = platform::GetCUDAMultiProcessors(device_id);
      int max_threads_per_mp =
          platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
      int max_threads = max_threads_per_mp * max_mp;

      // init
      int num_block = (max_threads / left_num);

      if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
        blocking_size = detail::GetLastPow2(reduce_num / num_block);

        if (blocking_size <= 1) {
          blocking_size = detail::GetLastPow2(sqrt(reduce_num));
        } else if (blocking_size * 2 < reduce_num) {
          blocking_size *= 2;
        }

        should_reduce_again = true;

        block_dim.x = 32;
        block_dim.y = 1;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;

      } else {
        block_dim.x = 32;
        block_dim.y = 1;
        blocking_size = reduce_num;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = 1;
      }
    } else {
      SetBlockDimForReduceAny(&block_dim, &grid_dim);
    }

    block = block_dim;
    grid = grid_dim;
  }

 public:
  std::vector<int> reduce_dims_origin;
  std::vector<int> reduce_dim;
  std::vector<int> x_dim;
  std::vector<int> left_dim;
  std::vector<int> x_strides;
  std::vector<int> left_strides;
  std::vector<int> reduce_strides;

  int reduce_type;
  int reduce_num;
  int left_num;
  int blocking_size;
  bool should_reduce_again;
  bool reduce_lastdim;

  Ty* output_data;

  dim3 block;
  dim3 grid;
};

static __device__ int SharedMemoryIndex(int index) {
  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}

template <typename T, typename ReduceOp>
static __device__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}
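
// Example: with kWarpSize = 32 and a sum reducer, the shuffle strides are
// 16, 8, 4, 2, 1; after the first step lane i holds val_i + val_(i+16), and
// after the last step lane 0 holds the sum over the whole warp.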

/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory
 *   \    \    /     /        ---->3. Load the result above from shared memory
 *        res                         to warp0 and process the second WarpReduce
 */
template <typename T, typename ReduceOp>
static __device__ T BlockXReduce(T val, ReduceOp reducer) {
  using detail::kWarpSize;
  __shared__ T shared[2 * kWarpSize];
  int block_dim_x = blockDim.x;
  if (blockDim.x > kWarpSize) {
    block_dim_x = blockDim.x / kWarpSize;
    int lane = threadIdx.x % kWarpSize;
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int wid = tid / kWarpSize;
    int bid = threadIdx.y;
    val = WarpReduce(val, reducer);
    if (lane == 0) {
      shared[wid] = val;
    }
    __syncthreads();
    val = shared[bid * block_dim_x + lane];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

template <typename T, typename ReduceOp>
static __device__ T BlockYReduce(T val, ReduceOp reducer) {
  __shared__ T shared_memory[detail::kMaxThread];
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
      T temp = shared_memory[SharedMemoryIndex(stride)];
      val = reducer(val, temp);
    }
    shared_memory[SharedMemoryIndex(0)] = val;
  }
  return val;
}

// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
//     if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
//     else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
                                TransformOp transformer, Ty init,
                                int reduce_num, int left_num, int block_size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int idy = blockIdx.y * block_size;

  Ty reduce_var = init;

  if (idx < left_num) {
    int loop = reduce_num - idy;
    loop = loop > block_size ? block_size : loop;

    for (int iy = 0; iy < loop; iy++) {
      int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
    }

    y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
        reduce_var;
  }
}
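
// Example of the indexing above: for x_dim = {ny, nx} reduced over axis 0
// (reduce_num = ny, left_num = nx, blockIdx.z = 0), the thread with
// idx = blockIdx.x * blockDim.x + threadIdx.x accumulates
// x[(idy + iy) * nx + idx] for iy in [0, loop) and writes its partial result
// to y[idx + blockIdx.y * nx]; when gridDim.y > 1 these partial results are
// reduced again by a second kernel launch (see should_reduce_again).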

// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          typename ReduceIndexCal, typename LeftIndexCal>
__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
                          TransformOp transformer, Ty init, int reduce_num,
                          int left_num, bool reduce_lastdim,
                          ReduceIndexCal reduce_index_calculator,
                          LeftIndexCal left_index_calculator) {
  int input_idx, left_idx, stride;
  // the last dim gets involved in reduction
  if (reduce_lastdim) {
    input_idx = blockIdx.y * blockDim.x + threadIdx.x;
    left_idx = blockIdx.x * blockDim.y + threadIdx.y;
    stride = gridDim.y * blockDim.x;
  } else {
    input_idx = blockIdx.y * blockDim.y + threadIdx.y;
    left_idx = blockIdx.x * blockDim.x + threadIdx.x;
    stride = gridDim.y * blockDim.y;
  }
  // calculate the offset, i.e. the address where each thread actually starts.
  int input_offset = left_index_calculator(left_idx);
  const Tx* input = x + input_offset;
  Ty reduce_var = init;

  // 1. reduce for each thread
  if (left_idx < left_num) {
    // load REDUCE_VEC_SIZE data once, and then compute
    Tx input_reg[REDUCE_VEC_SIZE];
    int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
    while (input_idx < bound) {
#pragma unroll
      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
        int reduce_idx = input_idx + i * stride;
        int idx_x = reduce_index_calculator(reduce_idx);
        input_reg[i] = input[idx_x];
      }
#pragma unroll
      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
        reduce_var = reducer(reduce_var, transformer(input_reg[i]));
      }
      input_idx += REDUCE_VEC_SIZE * stride;
    }

    // deal with the remaining part
    int input_idx_tmp = input_idx;
#pragma unroll
    for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
      if (input_idx >= reduce_num) {
        break;
      }
      int reduce_idx = input_idx;
      int idx_x = reduce_index_calculator(reduce_idx);
      input_reg[i] = input[idx_x];
      input_idx += stride;
    }
    input_idx = input_idx_tmp;
#pragma unroll
    for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
      if (input_idx >= reduce_num) {
        break;
      }
      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
      input_idx += stride;
    }
  }

  // 2. reduce in block y
  if (!reduce_lastdim && blockDim.y > 1) {
    reduce_var = BlockYReduce(reduce_var, reducer);
  }
  __syncthreads();

  if (reduce_lastdim) {
    // 3. reduce in block x
    reduce_var = BlockXReduce(reduce_var, reducer);
    if (left_idx < left_num && threadIdx.x == 0) {
      y[blockIdx.y * left_num + left_idx] = reduce_var;
    }
  } else {
    if (left_idx < left_num && threadIdx.y == 0) {
      y[blockIdx.y * left_num + left_idx] = reduce_var;
    }
  }
}

// device-side module function, called from the __global__ kernel function
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
                             TransformOp transformer, Ty init, int reduce_num,
                             int left_num, int blocking_size, int reduce_type,
                             bool reduce_lastdim,
                             const IndexCalculator& reduce_index_calculator,
                             const IndexCalculator& left_index_calculator) {
  if (reduce_type == ReduceType::kReduceLastDim) {
    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
        [&](int idx) { return idx; },
        [&](int idx) { return idx * reduce_num; });

    // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
  } else if (reduce_type == ReduceType::kReduceHigherDim) {
    ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);

    // reduce_rank >= 2
  } else {
    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
        [&](int idx) { return reduce_index_calculator.Get(idx); },
        [&](int idx) { return left_index_calculator.Get(idx); });
  }
}

template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
                                     TransformOp transformer, Ty init,
                                     int reduce_num, int left_num,
                                     int blocking_size, int reduce_type,
                                     bool reduce_lastdim,
                                     IndexCalculator reduce_index_calculator,
                                     IndexCalculator left_index_calculator) {
  ReduceModule<Tx, Ty, ReduceOp, TransformOp>(
      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
      reduce_type, reduce_lastdim, reduce_index_calculator,
      left_index_calculator);
}

template <typename Tx, typename Ty, typename ReduceOp>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                               const ReduceOp& reducer, Ty init,
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  using TransformOp = typename ReduceOp::Transformer;

  int reduce_rank = config.reduce_strides.size();
  int left_rank = config.left_strides.size();
  auto reduce_index_calculator = IndexCalculator(
      reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides);
  auto left_index_calculator = IndexCalculator(
      left_rank, config.left_dim, config.left_strides, config.x_strides);

  ReduceKernelFunction<Tx, Ty, ReduceOp,
                       TransformOp><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
      config.reduce_type, config.reduce_lastdim, reduce_index_calculator,
      left_index_calculator);

  if (config.should_reduce_again) {
    dim3 block;
    dim3 grid;
    if (config.reduce_lastdim) {
      block = dim3(32, 1, 1);
      grid = dim3(detail::AlignUp(config.left_num, 32), 1, 1);
    } else {
      block = dim3(config.block.x, 1, 1);
      grid = dim3(config.grid.x, 1, config.grid.z);
    }

    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<
                                               Ty>><<<grid, block, 0, stream>>>(
        config.output_data, y_data, reducer,
        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
        config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
  }
}

template <typename Tx, typename Ty,
          template <typename, typename> class ReduceOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                             std::vector<int> origin_reduce_dims,
                             gpuStream_t stream) {
  auto x_dim = framework::vectorize<int>(x.dims());
  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
  config.Run();  // get the parameters of LaunchReduceKernel

  // After config.Run(), SetOutputData is called below: when
  // should_reduce_again is true, the intermediate result (temp_data) is
  // stored in the output_data space; otherwise it is stored directly in
  // y_data.
  framework::Tensor tmp;
  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());

  if (config.reduce_num == 1) {
    auto out_dims = y->dims();
    framework::TensorCopy(x, y->place(), y);
    y->Resize(out_dims);
    return;
  }

  config.SetOutputData(y_data, x.place(), &tmp);

  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
  auto reducer = ReduceOp<Tx, Ty>();
  // launch CUB::Reduce
  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, TransformOp(config.reduce_num));
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);
    framework::Tensor tmp;
    auto* temp_storage = tmp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        x.place());
    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);

    return;
  }

  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
      x_data, y_data, reducer, reducer.initial(), stream, config);
}
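
// Minimal usage sketch. "CustomSum" below is a hypothetical reduce functor
// written only to show the interface this file relies on: a nested
// Transformer type constructible from the reduce count, an initial() member
// returning the identity value, and a binary device-side operator().
//
//   template <typename Tx, typename Ty>
//   struct CustomSum {
//     using Transformer = detail::IdentityFunctor<Tx, Ty>;
//     HOSTDEVICE inline Ty initial() { return static_cast<Ty>(0.0f); }
//     __device__ inline Ty operator()(const Ty& a, const Ty& b) const {
//       return a + b;
//     }
//   };
//
//   // Reduce dims {0, 1} of a float tensor x into y on the given stream:
//   TensorReduceFunctorImpl<float, float, CustomSum>(x, &y, {0, 1}, stream);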

template <typename Tx, template <typename, typename> class ReduceOp>
struct TensorReduceFunc {
  const framework::Tensor& x;
  framework::Tensor* y;
  std::vector<int> origin_reduce_dims;
  gpuStream_t stream;
  TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
                   std::vector<int> origin_reduce_dims, gpuStream_t stream)
      : x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}

  template <typename Ty>
  void apply() const {
    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
  }
};
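
// TensorReduceFunc defers the choice of the output type Ty to a visitor.
// A usage sketch, assuming the framework::VisitDataType dispatcher and the
// hypothetical CustomSum functor sketched above:
//
//   framework::VisitDataType(
//       static_cast<framework::proto::VarType::Type>(out_dtype),
//       TensorReduceFunc<float, CustomSum>(x, &y, {0}, stream));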

}  // namespace operators
}  // namespace paddle