reduce_op.cu.h 26.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/framework/array.h"
33
#include "paddle/fluid/framework/op_registry.h"
34 35 36
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"

37 38 39
// Size threshold used by ReduceConfig:
//  * a rank-2 reduction over dim 0 is routed to ReduceHigherDim only when
//    reduce_num > REDUCE_SPLIT_BOUNDARY / 2 or left_num > REDUCE_SPLIT_BOUNDARY;
//  * a ReduceHigherDim reduction is split along the reduced axis (and finished
//    by a second pass) only when reduce_num >= REDUCE_SPLIT_BOUNDARY.
#define REDUCE_SPLIT_BOUNDARY 512

40 41 42 43 44
namespace paddle {
namespace operators {
namespace detail {

// Post-processing functor for sum, max, min, prod and any: it only converts
// the value from Tx to Ty and performs no arithmetic.
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  // The element count is accepted for interface parity with DivideFunctor
  // and is deliberately ignored.
  HOSTDEVICE explicit inline IdentityFunctor(int /*n*/) {}

  HOSTDEVICE inline Ty operator()(const Tx& value) const {
    return static_cast<Ty>(value);
  }
};

// Post-processing functor for mean: scales a value by 1 / n. The reciprocal
// is computed once in double precision and stored as T.
// NOTE(review): for integral T with n > 1 the stored reciprocal truncates to
// 0, so every transformed value becomes 0.
template <typename T>
struct DivideFunctor {
  HOSTDEVICE explicit inline DivideFunctor(int n)
      : n_inv(static_cast<T>(1.0 / n)) {}

  HOSTDEVICE inline T operator()(const T& value) const {
    return value * n_inv;
  }

 private:
  T n_inv;  // 1 / n, converted to T
};

65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Normalizes the user-supplied reduce axes for an input of rank `dim_size`.
//
// When `reduce_all` is true, returns [0, 1, ..., dim_size - 1] and ignores
// `dims`. Otherwise each axis must lie in [-dim_size, dim_size). Negative axes
// are wrapped by adding dim_size, and the normalized axes are returned in
// input order. Duplicates are kept; ReduceConfig deduplicates them later.
// Fails with InvalidArgument for an out-of-range axis.
static inline std::vector<int> GetReduceDim(const std::vector<int>& dims,
                                            int dim_size, bool reduce_all) {
  std::vector<int> reduce_dims;
  if (reduce_all) {
    reduce_dims.resize(dim_size);
    std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
    return reduce_dims;
  }

  reduce_dims.reserve(dims.size());
  for (auto e : dims) {
    PADDLE_ENFORCE_LT(e, dim_size,
                      paddle::platform::errors::InvalidArgument(
                          "ReduceOp: invalid axis, when x_dims is %d, "
                          "axis[i] should less than x_dims, but got %d.",
                          dim_size, e));
    // Without this check an axis < -dim_size would stay negative after
    // wrapping and index out of bounds in ReduceConfig.
    PADDLE_ENFORCE_GE(e, -dim_size,
                      paddle::platform::errors::InvalidArgument(
                          "ReduceOp: invalid axis, when x_dims is %d, "
                          "axis[i] should be greater than or equal to "
                          "-x_dims, but got %d.",
                          dim_size, e));
    reduce_dims.push_back(e >= 0 ? e : e + dim_size);
  }
  return reduce_dims;
}

86 87 88 89 90 91 92 93 94
// Returns the largest power of two that is <= n, or 1 when n < 1.
static inline int GetLastPow2(int n) {
  // Propagate the highest set bit into every lower bit position...
  int filled = n;
  for (int shift = 1; shift <= 16; shift <<= 1) {
    filled |= (filled >> shift);
  }
  // ...then keep only that highest bit.
  const int top_bit = filled - (filled >> 1);
  return top_bit < 1 ? 1 : top_bit;
}

95 96 97
// Row-major strides of the sub-shape formed by dims[idx[0]], dims[idx[1]], ...
// The last selected dimension has stride 1, and each earlier one has the
// product of the selected dimensions after it. Used for x_dim, reduce_dim and
// left_dim in ReduceLastDim / ReduceAny. An empty `idx` yields an empty result.
static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
                                             const std::vector<int>& idx) {
  const int count = static_cast<int>(idx.size());
  std::vector<int> strides(count);
  int running = 1;
  for (int k = count - 1; k >= 0; --k) {
    strides[k] = running;
    if (k > 0) {
      running *= dims[idx[k]];
    }
  }
  return strides;
}

// Upper bound on the threads per block chosen by GetBlockDim:
// 256 when compiling with HIP, 128 otherwise.
#if defined(__HIPCC__)
constexpr int kMaxThread = 256;
#else
constexpr int kMaxThread = 128;
#endif

114 115 116
// Block size for ReduceLastDim / ReduceAny. Returns kMaxThread once block_dim
// reaches it; otherwise returns the largest power of two <= block_dim, or 1
// for block_dim < 1.
static inline int GetBlockDim(int block_dim) {
  if (block_dim < kMaxThread) {
    return GetLastPow2(block_dim);
  }
  return kMaxThread;
}

119 120
// Checks that the reduce rank is consistent with the rank of the merged shape.
// After ReduceConfig merges adjacent dimensions, reduced and kept dimensions
// alternate. An even rank therefore needs exactly rank / 2 reduced dims, and an
// odd rank needs (rank - 1) / 2 or (rank + 1) / 2. Fails with InvalidArgument
// otherwise.
static inline void CheckReduceRank(int reduce_rank, int rank) {
  if (rank % 2 != 0) {
    auto lower_rank = (rank - 1) / 2;
    auto upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE_EQ(
        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
        platform::errors::InvalidArgument(
            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
            "must be %d or %d, but got %d.",
            rank, lower_rank, upper_rank, reduce_rank));
    return;
  }

  PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                    platform::errors::InvalidArgument(
                        "ReduceOp: invalid reduce rank. When rank = %d, "
                        "reduce_rank must be %d, but got %d.",
                        rank, rank / 2, reduce_rank));
}

139
// Copies a vector-like container (anything with size() and operator[]) into a
// fixed-size framework::Array, so that the reduce kernels can take the values
// by value. Fails with InvalidArgument when vec.size() != ElementCount.
template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> VectorToArray(
    const VectorLikeType& vec) {
  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
                    platform::errors::InvalidArgument(
                        "Cub reduce Array: size not match. Received "
                        "vec.size() %d !=  ElementCount %d.",
                        vec.size(), ElementCount));
  paddle::framework::Array<T, ElementCount> result;
  // Sizes are equal here, so ElementCount bounds both containers.
  for (size_t pos = 0; pos < ElementCount; ++pos) {
    result[pos] = vec[pos];
  }
  return result;
}

}  // namespace detail

158 159
typedef framework::Tensor Tensor;  // shorthand used by ReduceCudaKernel

160
// Reduction flavour selected by ReduceConfig::SetReduceType after adjacent
// reduced/kept dimensions have been merged. The numeric values are used as
// the ReduceType template argument of the kernels.
enum ReduceType {
  kReduceAll = 0,        // every dimension reduced (reduce_rank == x_rank)
  kReduceLastDim = 1,    // rank-2 merged shape reduced along dim 1
  kReduceHigherDim = 2,  // one reduced dim that is not the last
                         // (ReduceFirstDim / ReduceSecondDim)
  kReduceAny = 3,        // every other pattern, e.g. several reduced dims
};

// Host-side launch configuration for the reduce kernels in this file.
//
// Construct it with the original reduce axes and the input shape, then call
// Run(). Run() performs four steps:
//   1. normalizes the axes and merges runs of adjacent dimensions that are all
//      reduced or all kept;
//   2. computes row-major strides of the merged shape;
//   3. classifies the reduction (ReduceType);
//   4. chooses block/grid dimensions.
// The public fields hold the results.
template <typename Ty>
struct ReduceConfig {
  ReduceConfig(std::vector<int> origin_reduce_dims, std::vector<int> x_dim)
      : reduce_dims_origin(origin_reduce_dims), x_dim(x_dim) {}

  // Computes every launch parameter. It must run before the public fields are
  // read.
  void Run() {
    // step1: update the reduce_dim left_dim and x_dim
    SetReduceDim();

    // step2: get the strides of dim for reduceAny and reduceLastDim
    SetStrides();

    // step3: get the type of reduce
    SetReduceType();

    // step4: set the block and grid for launch kernel
    SetBlockDim();
  }

  // Chooses where the first kernel pass writes. When should_reduce_again is
  // true, that pass emits grid.y partial results for every one of the
  // left_num * grid.z outputs. Those partials go to a temporary buffer owned by
  // `tmp`. Otherwise the first pass writes straight to y_data.
  void SetOutputData(Ty* y_data, const platform::Place& place,
                     framework::Tensor* tmp) {
    if (should_reduce_again) {
      // Element count, not bytes: the first ReduceHigherDim pass writes at
      // most left_num * grid.y * grid.z elements. The product is formed in
      // 64 bits to avoid overflow.
      const int64_t tmp_numel = static_cast<int64_t>(left_num) *
                                static_cast<int64_t>(grid.z) *
                                static_cast<int64_t>(grid.y);
      output_data =
          tmp->mutable_data<Ty>(framework::make_ddim({tmp_numel}), place);
    } else {
      output_data = y_data;
    }
  }

 private:
  // Normalizes the reduce axes, then merges every run of adjacent dimensions
  // that are all reduced or all kept into one dimension. It records which
  // merged dimensions are reduced (reduce_dim) and which are kept (left_dim).
  // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1]
  //     --SetReduceDim--> x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  void SetReduceDim() {
    const int origin_rank = static_cast<int>(x_dim.size());
    // An empty axis list previously indexed an empty vector (UB).
    PADDLE_ENFORCE_EQ(reduce_dims_origin.empty(), false,
                      platform::errors::InvalidArgument(
                          "ReduceOp: the reduce dims must not be empty."));

    std::vector<bool> is_reduced(origin_rank, false);
    for (auto e : reduce_dims_origin) {
      const int pos = e >= 0 ? e : e + origin_rank;
      PADDLE_ENFORCE_EQ(pos >= 0 && pos < origin_rank, true,
                        platform::errors::InvalidArgument(
                            "ReduceOp: invalid reduce axis %d for an input "
                            "of rank %d.",
                            e, origin_rank));
      is_reduced[pos] = true;  // duplicates collapse naturally
    }

    std::vector<int> x_new_dim;
    std::vector<int> reduce_dim_new;
    std::vector<int> left_dim_new;
    for (int i = 0; i < origin_rank; ++i) {
      if (i > 0 && is_reduced[i] == is_reduced[i - 1]) {
        // Same status as the previous dimension: fold into its group.
        x_new_dim.back() *= x_dim[i];
      } else {
        x_new_dim.push_back(x_dim[i]);
        const int group = static_cast<int>(x_new_dim.size()) - 1;
        if (is_reduced[i]) {
          reduce_dim_new.push_back(group);
        } else {
          left_dim_new.push_back(group);
        }
      }
    }

    x_dim.swap(x_new_dim);
    reduce_dim.swap(reduce_dim_new);
    left_dim.swap(left_dim_new);
  }

  // Sets x_strides, reduce_strides and left_strides for reduceLastDim and
  // reduceAny, plus the element counts reduce_num and left_num.
  // eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  //     --SetStrides--> x_strides= [6,1], reduce_strides = [1],
  //     left_strides = [1]
  void SetStrides() {
    std::vector<int> idx_dim(x_dim.size());
    std::iota(idx_dim.begin(), idx_dim.end(), 0);

    x_strides = detail::GetDimStrides(x_dim, idx_dim);
    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
    left_strides = detail::GetDimStrides(x_dim, left_dim);
    // reduce_dim is non-empty: SetReduceDim enforces at least one axis.
    reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

    left_num = 1;
    if (!left_dim.empty()) {
      left_num = left_strides[0] * x_dim[left_dim[0]];
    }
  }

  // get the reduceType
  // eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim -->reduceFirstDim
  //     x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim
  //     x_dim = [8] reduce_dim = [0] --> reduceAll
  //     x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny
  void SetReduceType() {
    int rank = x_dim.size();
    int reduce_rank = reduce_dim.size();
    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                           (left_num > REDUCE_SPLIT_BOUNDARY);

    if (rank == reduce_rank) {
      reduce_type = static_cast<int>(ReduceType::kReduceAll);
    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
    } else if (reduce_rank == 1 &&
               ((rank == 2 && is_large_enough) || rank != 2)) {
      // ReduceFirstDim and reduceSecondDim
      reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
    } else {
      reduce_type = static_cast<int>(ReduceType::kReduceAny);
    }
  }

  // Sets block and grid for the kernel launch.
  // For ReduceHigherDim: block(32, 1). If enough blocks are available and
  //   reduce_num is large, reduce_num is split into grid.y chunks of
  //   blocking_size rows, finished by a second pass (should_reduce_again).
  //   Otherwise grid.y == 1.
  // For the others: block(block_num, 1), grid(left_num, 1).
  void SetBlockDim() {
    int block_num = detail::GetBlockDim(reduce_num);
    should_reduce_again = false;

    dim3 block_dim(block_num, 1);
    dim3 grid_dim(left_num, 1);
    blocking_size = reduce_num;

    if (reduce_type == ReduceType::kReduceHigherDim) {
      int last_dim_num = x_dim.back();
      // update left_num: grid.z iterates over the leading kept dimension
      int grid_z = left_num / last_dim_num;
      left_num = last_dim_num;

      block_dim.z = 1;
      grid_dim.z = grid_z;

      int device_id = platform::GetCurrentDeviceId();
      int max_mp = platform::GetCUDAMultiProcessors(device_id);
      int max_threads_per_mp =
          platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
      int max_threads = max_threads_per_mp * max_mp;
      int num_block = (max_threads / left_num);

      block_dim.x = 32;
      block_dim.y = 1;
      grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;

      if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
        blocking_size = detail::GetLastPow2(reduce_num / num_block);
        if (blocking_size <= 1) {
          blocking_size = detail::GetLastPow2(
              static_cast<int>(std::sqrt(static_cast<double>(reduce_num))));
        } else if (blocking_size * 2 < reduce_num) {
          blocking_size *= 2;
        }
        should_reduce_again = true;
        grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;
      } else {
        blocking_size = reduce_num;
        grid_dim.y = 1;
      }
    }

    block = block_dim;
    grid = grid_dim;
  }

 public:
  std::vector<int> reduce_dims_origin;
  std::vector<int> reduce_dim;
  std::vector<int> x_dim;
  std::vector<int> left_dim;
  std::vector<int> x_strides;
  std::vector<int> left_strides;
  std::vector<int> reduce_strides;

  int reduce_type = 0;
  int reduce_num = 0;
  int left_num = 0;
  int blocking_size = 0;
  bool should_reduce_again = false;

  Ty* output_data = nullptr;

  dim3 block;
  dim3 grid;
};

415 416 417
// Reduces the last dimension of a merged rank-2 input (reduce_dim == {1}).
// Launch layout: one block per output row (blockIdx.x indexes the row), and
// the BlockDim threads of a block stride over that row's reduce_num elements.
// Assumes blockDim.x == BlockDim, since cub::BlockReduce is instantiated with
// BlockDim.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim>
__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                              ReduceOp reducer,
                                              TransformOp transformer, Ty init,
                                              int reduce_num) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
  // 64-bit row offset: blockIdx.x * reduce_num may exceed INT_MAX.
  const Tx* row = x + static_cast<int64_t>(blockIdx.x) * reduce_num;

  Ty reduce_var = init;
  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
    reduce_var =
        reducer(reduce_var, static_cast<Ty>(transformer(row[idx_y])));
  }
  __syncthreads();

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  // The block-wide aggregate is only valid in thread 0.
  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

442 443 444 445 446
// Reduces a single non-last dimension.
// x is indexed as [gridDim.z][reduce_num][left_num]. Each thread owns one
// column idx < left_num and folds up to block_size consecutive rows, starting
// at row blockIdx.y * block_size. It writes the partial result to y, indexed
// as [gridDim.z][gridDim.y][left_num]. When gridDim.y > 1, a second pass must
// combine the partials.
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
//     if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
//     else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
                                                ReduceOp reducer,
                                                TransformOp transformer,
                                                Ty init, int reduce_num,
                                                int left_num, int block_size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int idy = blockIdx.y * block_size;

  Ty reduce_var = init;

  if (idx < left_num) {
    int loop = reduce_num - idy;
    loop = loop > block_size ? block_size : loop;

    // 64-bit offsets: z * reduce_num * left_num and row * left_num can exceed
    // INT_MAX for large inputs.
    const int64_t plane = static_cast<int64_t>(reduce_num) * left_num;
    const Tx* column = x + static_cast<int64_t>(blockIdx.z) * plane + idx;

    for (int iy = 0; iy < loop; iy++) {
      const int64_t row_offset = static_cast<int64_t>(idy + iy) * left_num;
      reduce_var =
          reducer(reduce_var, static_cast<Ty>(transformer(column[row_offset])));
    }

    const int64_t out_offset =
        static_cast<int64_t>(blockIdx.z) * gridDim.y * left_num +
        static_cast<int64_t>(blockIdx.y) * left_num + idx;
    y[out_offset] = reduce_var;
  }
}

472 473 474
// Generic reduction used when reduce_dim.size() != 1 and reduce_dim.size() !=
// x_dim.size(): several merged dimensions are reduced, but not all of them.
// Launch layout: one block per output element. blockIdx.x is the linear index
// over the kept dimensions, and the BlockDim threads of a block stride over
// the reduce_num reduced elements before cub::BlockReduce combines them.
// Each thread reads the element at its own threadIdx.x before the loop, so
// this assumes BlockDim <= reduce_num. GetBlockDim appears to guarantee that;
// confirm at the launch site.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceAny(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
    int reduce_num, paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;

  // Multi-dimensional coordinate of the element being read.
  int sub_index[Rank];

  // The kept coordinates depend only on the block; decode them once.
  int left_rem = blockIdx.x;
  for (int d = 0; d < Rank - ReduceRank; ++d) {
    sub_index[left_dim[d]] = left_rem / left_strides[d];
    left_rem %= left_strides[d];
  }

  // Decodes a linear index over the reduced dimensions into sub_index and
  // returns the matching flat offset into x.
  auto offset_of = [&](int reduce_linear) {
    for (int r = 0; r < ReduceRank; ++r) {
      sub_index[reduce_dim[r]] = reduce_linear / reduce_strides[r];
      reduce_linear %= reduce_strides[r];
    }
    int offset = 0;
    for (int k = 0; k < Rank; ++k) {
      offset += (sub_index[k] * x_strides[k]);
    }
    return offset;
  };

  // Seed with this thread's first element, then stride by BlockDim.
  int pos = threadIdx.x;
  Ty reduce_var = static_cast<Ty>(transformer(x[offset_of(pos)]));
  for (pos += BlockDim; pos < reduce_num; pos += BlockDim) {
    reduce_var = static_cast<Ty>(
        reducer(reduce_var, static_cast<Ty>(transformer(x[offset_of(pos)]))));
  }
  __syncthreads();

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  // Only thread 0 holds the block-wide aggregate.
  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

531
// Device-side dispatcher used by ReduceKernelFunction. It forwards to the
// reduction routine that matches the compile-time ReduceType value.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank, int ReduceType>
__device__ __forceinline__ void ReduceModule(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int blocking_size,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  switch (ReduceType) {
    case kReduceLastDim:
      // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
      ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
          x, y, reducer, transformer, init, reduce_num);
      break;
    case kReduceHigherDim:
      // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
      ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
          x, y, reducer, transformer, init, reduce_num, left_num,
          blocking_size);
      break;
    default:
      // reduce_rank >= 2
      ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
          x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
          reduce_strides, left_dim, left_strides);
      break;
  }
}

// __global__ entry point launched by LaunchKernel. All work is delegated to
// ReduceModule, which selects the routine for ReduceType at compile time.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank, int ReduceType>
__global__ void ReduceKernelFunction(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int block_size,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  using Module = void (*)(const Tx*, Ty*, ReduceOp, TransformOp, Ty, int, int,
                          int, paddle::framework::Array<int, Rank>,
                          paddle::framework::Array<int, ReduceRank>,
                          paddle::framework::Array<int, ReduceRank>,
                          paddle::framework::Array<int, Rank - ReduceRank>,
                          paddle::framework::Array<int, Rank - ReduceRank>);
  (void)sizeof(Module);  // signature documentation only; no runtime effect
  ReduceModule<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank,
               ReduceType>(x, y, reducer, transformer, init, reduce_num,
                           left_num, block_size, x_strides, reduce_dim,
                           reduce_strides, left_dim, left_strides);
}

// Host launcher for one (kRank, kReduceRank) instantiation.
// It launches the first reduction pass chosen by config.reduce_type, which
// writes into config.output_data. When config.should_reduce_again is set, it
// then launches a second ReduceHigherDim pass on `stream`. That pass folds the
// config.grid.y partial rows from config.output_data into y_data.
// kReduceAll has no case here; the caller is expected to handle it separately.
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
          typename TransformOp, int kRank, int kReduceRank>
static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
                         const TransformOp& transformer, Ty init,
                         gpuStream_t stream, ReduceConfig<Ty> config) {
#define CUB_REDUCE_TYPE_CASE(type)                                             \
  case type: {                                                                 \
    constexpr int kReduceType = static_cast<int>(type);                        \
    ReduceKernelFunction<                                                      \
        Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank,           \
        kReduceType><<<config.grid, config.block, 0, stream>>>(                \
        x_data, config.output_data, reducer, transformer, init,                \
        config.reduce_num, config.left_num, config.blocking_size,              \
        detail::VectorToArray<int, kRank>(config.x_strides),                   \
        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),            \
        detail::VectorToArray<int, kReduceRank>(config.reduce_strides),        \
        detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),      \
        detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides)); \
  } break

  switch (config.reduce_type) {
    CUB_REDUCE_TYPE_CASE(ReduceType::kReduceLastDim);
    CUB_REDUCE_TYPE_CASE(ReduceType::kReduceHigherDim);
    CUB_REDUCE_TYPE_CASE(ReduceType::kReduceAny);
    default:
      // kReduceAll is expected to be handled by the caller.
      break;
  }
  // This is a header: do not leak the helper macro to includers.
#undef CUB_REDUCE_TYPE_CASE

  if (config.should_reduce_again) {
    dim3 block(config.block.x, 1, 1);
    dim3 grid(config.grid.x, 1, config.grid.z);

    // grid.y is both the number of partial rows to fold (reduce_num) and the
    // per-thread row count (block_size), so one pass covers all partials.
    ReduceKernelFunction<
        Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128, kRank, kReduceRank,
        ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
        config.output_data, y_data, reducer,
        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y,
        detail::VectorToArray<int, kRank>(config.x_strides),
        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
        detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
        detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
        detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
  }
}

// Maps the runtime (rank, reduce_rank) of the merged shape onto the
// compile-time instantiations of LaunchKernel. Ranks 2 to 9 are instantiated,
// and for each rank only the reduce ranks CheckReduceRank accepts. Any other
// rank fails with an error instead of silently launching nothing and leaving
// the output unwritten.
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
          typename TransformOp>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                               const ReduceOp& reducer,
                               const TransformOp& transformer, Ty init,
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  int reduce_rank = config.reduce_strides.size();
  int rank = config.x_strides.size();

#define CUB_RANK_CASE(i, ...)             \
  case i: {                               \
    constexpr auto kRank = i;             \
    switch (reduce_rank) { __VA_ARGS__; } \
  } break

#define CUB_REDUCE_RANK_CASE(i, ...)                                           \
  case i: {                                                                    \
    constexpr auto kReduceRank = i;                                            \
    LaunchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>( \
        x_data, y_data, reducer, transformer, init, stream, config);           \
  } break

  // Guarantees the inner switch on reduce_rank always finds a case.
  detail::CheckReduceRank(reduce_rank, rank);
  switch (rank) {
    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););

    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););

    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "ReduceOp: the GPU reduce kernel supports merged ranks from 2 to 9, "
          "but got rank %d.",
          rank));
  }

#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}
664 665 666 667 668 669

// Reduces tensor `x` (element type Tx) over `origin_reduce_dims` into `y`
// (element type Ty) on `stream`, using ReduceOp<Tx, Ty> for the combine step
// and its Transformer for per-element pre-processing. There are three paths:
//  * reduce_num == 1: nothing is reduced, so y is a copy of x reshaped to y's
//    dims;
//  * kReduceAll: a single cub::DeviceReduce::Reduce over all elements;
//  * otherwise: the custom kernels, dispatched on the block size.
template <typename Tx, typename Ty,
          template <typename, typename> class ReduceOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                             std::vector<int> origin_reduce_dims,
                             gpuStream_t stream) {
  auto x_dim = framework::vectorize<int>(x.dims());
  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
  config.Run();  // get the parameters of LaunchReduceKernel

  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());

  if (config.reduce_num == 1) {
    // NOTE(review): TensorCopy appears to keep x's dtype, so when Tx != Ty
    // (out_dtype set) y may end up holding Tx data. Confirm whether a cast is
    // needed here.
    auto out_dims = y->dims();
    framework::TensorCopy(x, y->place(), y);
    y->Resize(out_dims);
    return;
  }

  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
  auto reducer = ReduceOp<Tx, Ty>();

  // launch CUB::Reduce
  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, TransformOp(config.reduce_num));
    size_t temp_storage_bytes = 0;
    // A null temp pointer only queries the required temporary storage size.
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);
    framework::Tensor cub_temp;
    auto* temp_storage = cub_temp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        x.place());
    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);
    return;
  }

  // Only needed by the custom kernels. When should_reduce_again is true,
  // `tmp` owns the buffer for the first pass's partial results; otherwise
  // output_data is simply y_data.
  framework::Tensor tmp;
  config.SetOutputData(y_data, x.place(), &tmp);

#define CUB_BLOCK_DIM_CASE(block_dim)                                     \
  case block_dim: {                                                       \
    LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp<Tx, Ty>, TransformOp>( \
        x_data, y_data, reducer, TransformOp(config.reduce_num),          \
        reducer.initial(), stream, config);                               \
  } break

  // GetBlockDim returns a power of two in [1, kMaxThread]. The value 1 implies
  // reduce_num == 1, which returned early above.
  switch (detail::GetBlockDim(config.reduce_num)) {
    CUB_BLOCK_DIM_CASE(256);
    CUB_BLOCK_DIM_CASE(128);
    CUB_BLOCK_DIM_CASE(64);
    CUB_BLOCK_DIM_CASE(32);
    CUB_BLOCK_DIM_CASE(16);
    CUB_BLOCK_DIM_CASE(8);
    CUB_BLOCK_DIM_CASE(4);
    CUB_BLOCK_DIM_CASE(2);
  }
#undef CUB_BLOCK_DIM_CASE
}

732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772
// Visitor handed to framework::VisitDataTypeSmall. apply<Ty>() runs the
// reduction with input element type Tx and the visited output type Ty.
template <typename Tx, template <typename, typename> class ReduceOp>
struct TensorReduceFunc {
  const framework::Tensor& x;
  framework::Tensor* y;
  std::vector<int> origin_reduce_dims;
  gpuStream_t stream;

  TensorReduceFunc(const framework::Tensor& input, framework::Tensor* output,
                   std::vector<int> reduce_dims, gpuStream_t cuda_stream)
      : x(input),
        y(output),
        origin_reduce_dims(reduce_dims),
        stream(cuda_stream) {}

  template <typename Ty>
  void apply() const {
    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims,
                                              stream);
  }
};

// GPU kernel for the reduce family of operators. It reads the "reduce_all",
// "dim" and "out_dtype" attributes and normalizes the axes. When out_dtype is
// negative, it reduces X into Out keeping element type T; otherwise it
// dispatches on out_dtype to choose the output element type.
template <typename T, template <typename, typename> class ReduceOp>
class ReduceCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const bool reduce_all = context.Attr<bool>("reduce_all");
    const Tensor* input = context.Input<Tensor>("X");
    Tensor* output = context.Output<Tensor>("Out");
    const int out_dtype = context.Attr<int>("out_dtype");
    const std::vector<int> dims = context.Attr<std::vector<int>>("dim");

    const std::vector<int> reduce_dims =
        detail::GetReduceDim(dims, input->dims().size(), reduce_all);
    gpuStream_t stream = context.cuda_device_context().stream();

    if (out_dtype < 0) {
      // Output keeps the input element type.
      TensorReduceFunctorImpl<T, T, ReduceOp>(*input, output, reduce_dims,
                                              stream);
      return;
    }

    framework::VisitDataTypeSmall(
        static_cast<framework::proto::VarType::Type>(out_dtype),
        TensorReduceFunc<T, ReduceOp>(*input, output, reduce_dims, stream));
  }
};

773 774
}  // namespace operators
}  // namespace paddle