reduce_op.cu.h 24.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/framework/array.h"
33
#include "paddle/fluid/framework/op_registry.h"
34 35 36
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"

37 38 39
// Size threshold used by ReduceConfig:
//  - SetReduceType: for rank-2 inputs with one reduced axis, ReduceHigherDim is
//    only chosen when reduce_num > REDUCE_SPLIT_BOUNDARY / 2 or
//    left_num > REDUCE_SPLIT_BOUNDARY.
//  - SetBlockDim: ReduceHigherDim is split across several blocks (with a
//    second reduction pass) only when reduce_num >= REDUCE_SPLIT_BOUNDARY.
#define REDUCE_SPLIT_BOUNDARY 512

40 41 42 43 44
namespace paddle {
namespace operators {
namespace detail {

// Post-processing functor for sum, max, min, prod and any: applies no
// arithmetic, it only converts the input value from Tx to Ty.
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  // The element count is accepted so this functor can be constructed the same
  // way as DivideFunctor; it is intentionally unused.
  HOSTDEVICE explicit inline IdentityFunctor(int /*n*/) {}

  // Returns `x` converted to Ty.
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    Ty converted = static_cast<Ty>(x);
    return converted;
  }
};

// Post-processing functor for mean: scales each value by 1 / n, where n is
// the element count given to the constructor.
template <typename T>
struct DivideFunctor {
  // The reciprocal is computed once in double precision and then converted
  // to T, so operator() is a single multiplication.
  HOSTDEVICE explicit inline DivideFunctor(int n)
      : n_inv(static_cast<T>(1.0 / n)) {}

  HOSTDEVICE inline T operator()(const T& x) const {
    const T scaled = x * n_inv;
    return scaled;
  }

 private:
  T n_inv;  // 1 / n, stored as T
};

// Returns the largest power of two that is <= n; yields 1 when n is 0 (and
// for negative n on compilers with arithmetic right shift).
static inline int GetLastPow2(int n) {
  // Copy the highest set bit into every lower bit position...
  for (int shift = 1; shift <= 16; shift <<= 1) {
    n |= (n >> shift);
  }
  // ...then isolate that highest bit.
  const int highest_bit = n - (n >> 1);
  return highest_bit > 1 ? highest_bit : 1;
}

74 75 76
// Row-major strides of the sub-shape formed by the axes listed in `idx`
// (sizes taken from `dims`): the last listed axis gets stride 1 and every
// earlier one the product of the sizes of the listed axes after it.
// Used for x_dim, reduce_dim and left_dim (reduceLastDim and reduceAny).
static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
                                             const std::vector<int>& idx) {
  // Every stride starts at 1; the last listed axis keeps that value.
  std::vector<int> strides(idx.size(), 1);
  for (size_t i = idx.size(); i > 1; --i) {
    strides[i - 2] = strides[i - 1] * dims[idx[i - 1]];
  }
  return strides;
}

// Upper bound on the block size returned by GetBlockDim; HIP builds use a
// larger cap than CUDA builds.
constexpr int kMaxThread =
#ifdef __HIPCC__
    256;
#else
    128;
#endif

93 94 95
// Block size for reduceLastDim and reduceAny: the largest power of two not
// exceeding `block_dim`, capped at kMaxThread.
static inline int GetBlockDim(int block_dim) {
  if (block_dim < kMaxThread) {
    return GetLastPow2(block_dim);
  }
  return kMaxThread;
}

98 99
// check reduce rank is valid
// ReduceConfig::SetReduceDim merges adjacent axes that are both reduced or
// both kept, so reduced and kept axes alternate. Therefore an even `rank`
// must have exactly rank / 2 reduced axes, and an odd `rank` either
// (rank - 1) / 2 or (rank + 1) / 2. Fails with InvalidArgument otherwise.
static inline void CheckReduceRank(int reduce_rank, int rank) {
  if (rank % 2 != 0) {
    auto lower_rank = (rank - 1) / 2;
    auto upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE_EQ(
        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
        platform::errors::InvalidArgument(
            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
            "must be %d or %d, but got %d.",
            rank, lower_rank, upper_rank, reduce_rank));
    return;
  }

  PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                    platform::errors::InvalidArgument(
                        "ReduceOp: invalid reduce rank. When rank = %d, "
                        "reduce_rank must be %d, but got %d.",
                        rank, rank / 2, reduce_rank));
}

118
// convert dims from vector to array
// Copies every element of `vec` into a fixed-size framework::Array, the form
// in which shape metadata is passed as kernel arguments. Fails with
// InvalidArgument unless vec.size() == ElementCount.
template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> VectorToArray(
    const VectorLikeType& vec) {
  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
                    platform::errors::InvalidArgument(
                        "Cub reduce Array: size not match. Received "
                        "vec.size() %d !=  ElementCount %d.",
                        vec.size(), ElementCount));
  paddle::framework::Array<T, ElementCount> ret;
  const auto count = static_cast<size_t>(vec.size());
  for (size_t pos = 0; pos != count; ++pos) {
    ret[pos] = vec[pos];
  }
  return ret;
}

}  // namespace detail

137 138
// Local shorthand for framework::Tensor.
typedef framework::Tensor Tensor;

139
// Reduction strategy chosen by ReduceConfig::SetReduceType. It is stored as
// an int in ReduceConfig::reduce_type and forwarded to the kernels.
enum ReduceType {
  // Every axis is reduced (reduce_rank == x_rank).
  kReduceAll = 0,
  // Rank-2 input whose last axis is the reduced one.
  kReduceLastDim = 1,
  // A single reduced axis that is not the last one
  // (ReduceFirstDim / reduceSecondDim).
  kReduceHigherDim = 2,
  // All remaining cases, e.g. several reduced axes.
  kReduceAny = 3,
};

// reduce config
// Host-side launch configuration for the reduce kernels. Run() normalizes the
// reduced axes, merges adjacent axes, computes strides and element counts,
// picks a ReduceType and sizes block/grid. SetOutputData() then decides where
// the first kernel writes its results.
template <typename Ty>
struct ReduceConfig {
  ReduceConfig(const std::vector<int>& origin_reduce_dims,
               const std::vector<int>& origin_x_dim)
      : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}

  // get the parameters of reduceKernel
  void Run() {
    // step1: update the reduce_dim left_dim and x_dim
    SetReduceDim();

    // step2: get the strides of dim for reduceAny and reduceLastDim
    SetStrides();

    // step3: get the type of reduce
    SetReduceType();

    // step4: set the block and grid for launch kernel
    SetBlockDim();
  }

  // When should_reduce_again is true, the first kernel writes partial results
  // into `tmp` (allocated here on `place`) and a second pass reduces them into
  // y_data. Otherwise the first kernel writes y_data directly.
  void SetOutputData(Ty* y_data, const platform::Place& place,
                     framework::Tensor* tmp) {
    if (should_reduce_again) {
      // ReduceHigherDim writes one partial per (blockIdx.z, blockIdx.y,
      // column), i.e. left_num * grid.y * grid.z values. mutable_data<Ty>
      // takes an element count, so no sizeof(Ty) factor belongs here; the
      // product is formed in 64 bits to avoid int overflow.
      const int64_t partial_num =
          static_cast<int64_t>(left_num) * grid.z * grid.y;
      output_data =
          tmp->mutable_data<Ty>(framework::make_ddim({partial_num}), place);
    } else {
      output_data = y_data;
    }
  }

 private:
  // set reduce_dim, left_dim and update x_dim
  // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1]
  //     --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
  void SetReduceDim() {
    // reduce_dim_temp[0] is read below, so an empty axis list is an error.
    PADDLE_ENFORCE_EQ(reduce_dims_origin.empty(), false,
                      platform::errors::InvalidArgument(
                          "ReduceOp: the reduce dims must not be empty."));

    // Map negative axes to their positive equivalent and drop duplicates.
    // std::set iterates in ascending order, so no extra sort is needed.
    std::set<int> reduce_set;
    for (auto e : reduce_dims_origin) {
      auto pos = e >= 0 ? e : e + x_dim.size();
      reduce_set.insert(pos);
    }

    std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());

    // update reduce_dim and x_dim
    // First pass: fold runs of consecutive reduced axes into one axis.
    std::vector<int> x_new_dim;

    reduce_dim.push_back(reduce_dim_temp[0]);
    x_new_dim.push_back(x_dim[0]);

    int idx_reduce = 1;
    int num = 0;  // number of axes folded away so far

    if (reduce_dim_temp.size() > 1) {
      for (int i = 1; i < x_dim.size(); i++) {
        if ((idx_reduce < reduce_dim_temp.size()) &&
            (i == reduce_dim_temp[idx_reduce])) {
          int result =
              reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
          bool is_equal = ((result - num) == 1);
          if (is_equal) {
            x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
            num++;
          } else {
            reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
            x_new_dim.push_back(x_dim[i]);
          }
          idx_reduce++;
        } else {
          x_new_dim.push_back(x_dim[i]);
        }
      }
    } else {
      x_new_dim = x_dim;
    }

    // update x_dim
    x_dim = x_new_dim;
    std::vector<int>().swap(x_new_dim);

    // Second pass: fold every run of adjacent axes that are all reduced or
    // all kept, so that reduced and kept axes alternate afterwards.
    std::vector<int> reduce_dim_new;
    int is_reduced = 0;  // bit i is set <=> axis i is reduced
    for (auto e : reduce_dim) {
      is_reduced |= 1 << e;
    }

    std::vector<int>().swap(reduce_dim);

    for (int i = 0; i < x_dim.size(); i++) {
      if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
        x_new_dim.push_back(x_dim[i]);
        if ((is_reduced >> i) & 1)
          reduce_dim_new.push_back(x_new_dim.size() - 1);
      } else {
        x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
      }
    }

    x_dim = x_new_dim;
    reduce_dim = reduce_dim_new;

    // left_dim: every axis of the merged x_dim that is not reduced.
    int x_rank = static_cast<int>(x_dim.size());
    std::set<int> left_set;

    for (int i = 0; i < x_rank; ++i) {
      left_set.insert(i);
    }

    for (auto e : reduce_dim) {
      left_set.erase(e);
    }

    left_dim.assign(left_set.begin(), left_set.end());
  }

  // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
  // eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  //     --SetStrides--> x_strides= [6,1], reduce_strides = [1],
  //     left_strides = [1]
  // Also sets reduce_num (product of reduced sizes) and left_num (product of
  // kept sizes, 1 when no axis is kept).
  void SetStrides() {
    std::vector<int> idx_dim;
    for (int i = 0; i < x_dim.size(); i++) {
      idx_dim.push_back(i);
    }

    x_strides = detail::GetDimStrides(x_dim, idx_dim);
    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
    left_strides = detail::GetDimStrides(x_dim, left_dim);
    reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

    left_num = 1;
    if (left_dim.size()) {
      left_num = left_strides[0] * x_dim[left_dim[0]];
    }
  }

  // get the reduceType
  // eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim -->reduceFirstDim
  //     x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim
  //     x_dim = [8] reduce_dim = [0] --> reduceAll
  //     x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny
  void SetReduceType() {
    int rank = x_dim.size();
    int reduce_rank = reduce_dim.size();
    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                           (left_num > REDUCE_SPLIT_BOUNDARY);

    if (rank == reduce_rank) {
      reduce_type = static_cast<int>(ReduceType::kReduceAll);

    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);

    } else if (reduce_rank == 1 &&
               ((rank == 2 && is_large_enough) || rank != 2)) {
      // ReduceFirstDim and reduceSecondDim
      reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);

    } else {
      reduce_type = static_cast<int>(ReduceType::kReduceAny);
    }
  }

  // set block and grid for launch kernel
  // for ReduceHigherDim: if block is enough -> splite reduce_num
  //                     else init block(32, 1) grid(block_num, 1)
  // for others: block(block_num, 1) , grid(left_num, 1)
  void SetBlockDim() {
    // init
    int block_num = detail::GetBlockDim(reduce_num);
    should_reduce_again = false;

    dim3 block_dim(block_num, 1);
    dim3 grid_dim(left_num, 1);
    blocking_size = reduce_num;

    if (reduce_type == ReduceType::kReduceHigherDim) {
      int last_dim_num = x_dim.back();
      // update left_num
      int grid_z = left_num / last_dim_num;
      left_num = last_dim_num;

      block_dim.z = 1;
      grid_dim.z = grid_z;

      int device_id = platform::GetCurrentDeviceId();
      int max_mp = platform::GetCUDAMultiProcessors(device_id);
      int max_threads_per_mp =
          platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
      int max_threads = max_threads_per_mp * max_mp;

      // init
      int num_block = (max_threads / left_num);

      if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
        // Split the reduced axis into chunks of blocking_size rows, one chunk
        // per grid.y index; a second pass then combines the chunks.
        blocking_size = detail::GetLastPow2(reduce_num / num_block);

        if (blocking_size <= 1) {
          blocking_size = detail::GetLastPow2(sqrt(reduce_num));
        } else if (blocking_size * 2 < reduce_num) {
          blocking_size *= 2;
        }

        should_reduce_again = true;

        block_dim.x = 32;
        block_dim.y = 1;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;

      } else {
        block_dim.x = 32;
        block_dim.y = 1;
        blocking_size = reduce_num;
        grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
        grid_dim.y = 1;
      }
    }

    block = block_dim;
    grid = grid_dim;
  }

 public:
  std::vector<int> reduce_dims_origin;
  std::vector<int> reduce_dim;
  std::vector<int> x_dim;
  std::vector<int> left_dim;
  std::vector<int> x_strides;
  std::vector<int> left_strides;
  std::vector<int> reduce_strides;

  int reduce_type;
  int reduce_num;
  int left_num;
  int blocking_size;
  bool should_reduce_again;

  Ty* output_data;

  dim3 block;
  dim3 grid;
};

395 396 397
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
// function will be used
// blockId.x -> left_num, threadId.x -> reduce_num
// Each block reduces one contiguous row of `reduce_num` elements of x. Its
// threads visit the row BlockDim elements apart, the per-thread values are
// combined with cub::BlockReduce, and thread 0 writes the row result to
// y[blockIdx.x]. Expects blockDim.x == BlockDim (cub::BlockReduce is
// instantiated for exactly BlockDim threads).
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim>
__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                              ReduceOp reducer,
                                              TransformOp transformer, Ty init,
                                              int reduce_num) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
  // Row start computed in 64 bits: blockIdx.x * reduce_num can exceed
  // INT_MAX for large inputs.
  const Tx* row = x + static_cast<int64_t>(blockIdx.x) * reduce_num;
  Ty reduce_var = init;

  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
    reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(row[idx_y])));
  }
  __syncthreads();

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

422 423 424 425 426
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
//     if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
//     else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
// x is addressed as slices (blockIdx.z) of [reduce_num, left_num]. Each
// thread owns one column of its slice and folds at most block_size rows,
// starting at row blockIdx.y * block_size. It writes the partial result to
// y[(blockIdx.z * gridDim.y + blockIdx.y) * left_num + column].
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
                                                ReduceOp reducer,
                                                TransformOp transformer,
                                                Ty init, int reduce_num,
                                                int left_num, int block_size) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads beyond the last column have nothing to do.
  if (col >= left_num) {
    return;
  }

  const int row_begin = blockIdx.y * block_size;
  // Rows remaining in this chunk; if it is <= 0 the loop is skipped and
  // `init` is written.
  int row_count = reduce_num - row_begin;
  if (row_count > block_size) {
    row_count = block_size;
  }

  Ty reduce_var = init;
  for (int r = 0; r < row_count; ++r) {
    const int id =
        (row_begin + r) * left_num + col + blockIdx.z * reduce_num * left_num;
    reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
  }

  y[col + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
      reduce_var;
}

452 453 454
// Helper for ReduceAny. It decomposes `reduce_offset` (a flat index into the
// reduced sub-space) into coordinates and stores them in
// coords[reduce_dim[j]]. The kept-axis entries of `coords` are left
// untouched. Returns the flat offset of the full coordinate into x.
template <int Rank, int ReduceRank>
__device__ __forceinline__ int ReduceAnyOffset(
    int reduce_offset, int* coords,
    const paddle::framework::Array<int, Rank>& x_strides,
    const paddle::framework::Array<int, ReduceRank>& reduce_dim,
    const paddle::framework::Array<int, ReduceRank>& reduce_strides) {
  for (int j = 0; j < ReduceRank; ++j) {
    coords[reduce_dim[j]] = reduce_offset / reduce_strides[j];
    reduce_offset %= reduce_strides[j];
  }
  int offset = 0;
  for (int k = 0; k < Rank; ++k) {
    offset += coords[k] * x_strides[k];
  }
  return offset;
}

// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
// blockId.x -> left_num, threadId.x -> reduce_num
// Each block produces y[blockIdx.x]. blockIdx.x selects the kept-axis
// coordinates; the threads walk the reduced sub-space BlockDim elements
// apart. Every thread reads its first element unconditionally, so this
// assumes blockDim.x <= reduce_num.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceAny(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
    int reduce_num, paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;

  // Full coordinate of the element being read. The kept-axis part is fixed
  // for the whole block; ReduceAnyOffset overwrites the reduced part.
  int coords[Rank];
  int left_offset = blockIdx.x;
  for (int i = 0; i < Rank - ReduceRank; ++i) {
    coords[left_dim[i]] = left_offset / left_strides[i];
    left_offset %= left_strides[i];
  }

  // Seed with this thread's first element (no `init` is used here).
  const int first = ReduceAnyOffset<Rank, ReduceRank>(
      threadIdx.x, coords, x_strides, reduce_dim, reduce_strides);
  Ty reduce_var = static_cast<Ty>(transformer(x[first]));

  for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
    const int idx_x = ReduceAnyOffset<Rank, ReduceRank>(
        i, coords, x_strides, reduce_dim, reduce_strides);
    reduce_var = static_cast<Ty>(
        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
  }
  __syncthreads();

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

511
// module function designed for global function
// Routes to the device routine that matches reduce_type (a ReduceType value):
//   kReduceLastDim   -> ReduceLastDim   (one reduced axis, the last one)
//   kReduceHigherDim -> ReduceHigherDim (one reduced axis, not the last one)
//   any other value  -> ReduceAny
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceModule(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int blocking_size, int reduce_type,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  switch (reduce_type) {
    case ReduceType::kReduceLastDim:
      ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
          x, y, reducer, transformer, init, reduce_num);
      break;

    case ReduceType::kReduceHigherDim:
      ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
          x, y, reducer, transformer, init, reduce_num, left_num,
          blocking_size);
      break;

    default:
      ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
          x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
          reduce_strides, left_dim, left_strides);
      break;
  }
}

// Kernel entry point: forwards every argument unchanged to ReduceModule,
// which picks the device routine from reduce_type. LaunchReduceKernel
// launches it with block/grid shapes derived from ReduceConfig.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank>
__global__ void ReduceKernelFunction(
    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
    int reduce_num, int left_num, int block_size, int reduce_type,
    paddle::framework::Array<int, Rank> x_strides,
    paddle::framework::Array<int, ReduceRank> reduce_dim,
    paddle::framework::Array<int, ReduceRank> reduce_strides,
    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
  ReduceModule<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
      x, y,                                           // buffers
      reducer, transformer, init,                     // reduction semantics
      reduce_num, left_num, block_size, reduce_type,  // sizes and dispatch
      x_strides, reduce_dim, reduce_strides,          // reduced-axis layout
      left_dim, left_strides);                        // kept-axis layout
}

556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571
// Launches the reduce kernel described by `config` on `stream`. When
// config.should_reduce_again is set, the first launch writes partial results
// to config.output_data. A second ReduceHigherDim launch (IdentityFunctor
// transform) then folds the config.grid.y partials per column into y_data.
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp, int kRank,
          int kReduceRank>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                               const ReduceOp& reducer, Ty init,
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  using TransformOp = typename ReduceOp::Transformer;
  constexpr int kLeftRank = kRank - kReduceRank;

  // Shape metadata packed once; both launches receive identical copies.
  const auto x_strides = detail::VectorToArray<int, kRank>(config.x_strides);
  const auto reduce_dim =
      detail::VectorToArray<int, kReduceRank>(config.reduce_dim);
  const auto reduce_strides =
      detail::VectorToArray<int, kReduceRank>(config.reduce_strides);
  const auto left_dim = detail::VectorToArray<int, kLeftRank>(config.left_dim);
  const auto left_strides =
      detail::VectorToArray<int, kLeftRank>(config.left_strides);

  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,
                       kReduceRank><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
      config.reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
      left_strides);

  if (!config.should_reduce_again) {
    return;
  }

  const dim3 second_block(config.block.x, 1, 1);
  const dim3 second_grid(config.grid.x, 1, config.grid.z);
  ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128,
                       kRank, kReduceRank>
      <<<second_grid, second_block, 0, stream>>>(
          config.output_data, y_data, reducer,
          detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
          config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
          x_strides, reduce_dim, reduce_strides, left_dim, left_strides);
}

590 591 592 593
// Maps the runtime (rank, reduce_rank) of `config` onto a compile-time
// LaunchReduceKernel instantiation. Only the pairs accepted by
// detail::CheckReduceRank for ranks 2..9 are instantiated; any other rank
// launches nothing.
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp>
static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
                             const ReduceOp& reducer, Ty init,
                             gpuStream_t stream, ReduceConfig<Ty> config) {
  int reduce_rank = config.reduce_strides.size();
  int rank = config.x_strides.size();

// Launches the (R, RR) instantiation when the runtime reduce_rank is RR.
#define REDUCE_LAUNCH_IF_RANKS(R, RR)                         \
  if (reduce_rank == (RR)) {                                  \
    LaunchReduceKernel<Tx, Ty, BlockDim, ReduceOp, R, RR>(    \
        x_data, y_data, reducer, init, stream, config);       \
  }

  detail::CheckReduceRank(reduce_rank, rank);

  switch (rank) {
    case 2:
      REDUCE_LAUNCH_IF_RANKS(2, 1)
      break;
    case 3:
      REDUCE_LAUNCH_IF_RANKS(3, 1)
      REDUCE_LAUNCH_IF_RANKS(3, 2)
      break;
    case 4:
      REDUCE_LAUNCH_IF_RANKS(4, 2)
      break;
    case 5:
      REDUCE_LAUNCH_IF_RANKS(5, 2)
      REDUCE_LAUNCH_IF_RANKS(5, 3)
      break;
    case 6:
      REDUCE_LAUNCH_IF_RANKS(6, 3)
      break;
    case 7:
      REDUCE_LAUNCH_IF_RANKS(7, 3)
      REDUCE_LAUNCH_IF_RANKS(7, 4)
      break;
    case 8:
      REDUCE_LAUNCH_IF_RANKS(8, 4)
      break;
    case 9:
      REDUCE_LAUNCH_IF_RANKS(9, 4)
      REDUCE_LAUNCH_IF_RANKS(9, 5)
      break;
    default:
      break;
  }

#undef REDUCE_LAUNCH_IF_RANKS
}
632 633 634 635 636 637

namespace detail {
// Element-wise fallback for reduce_num == 1 when the output dtype differs
// from the input dtype: y[i] = Ty(transformer(x[i])) for i in [0, numel).
// Uses a grid-stride loop, so any grid size covers all elements.
template <typename Tx, typename Ty, typename TransformOp>
__global__ void TransformCastKernel(const Tx* x, Ty* y,
                                    TransformOp transformer, int64_t numel) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < numel; i += stride) {
    y[i] = static_cast<Ty>(transformer(x[i]));
  }
}
}  // namespace detail

// Reduces `x` over `origin_reduce_dims` into `y` (element type Ty) on
// `stream`, using ReduceOp<Tx, Ty> as reducer, initial value and transformer
// source.
//  - reduce_num == 1: nothing is reduced. y receives a copy of x, or a cast
//    copy when its dtype differs from x's.
//  - kReduceAll: a single cub::DeviceReduce::Reduce over all elements.
//  - otherwise: the custom kernels, with the block size from GetBlockDim.
template <typename Tx, typename Ty,
          template <typename, typename> class ReduceOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                             std::vector<int> origin_reduce_dims,
                             gpuStream_t stream) {
  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;

  auto x_dim = framework::vectorize<int>(x.dims());
  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
  config.Run();  // get the parameters of LaunchReduceKernel

  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());

  if (config.reduce_num == 1) {
    if (x.type() == y->type()) {
      auto out_dims = y->dims();
      framework::TensorCopy(x, y->place(), y);
      y->Resize(out_dims);
    } else {
      // TensorCopy would give y the dtype of x, so convert element-wise into
      // the Ty buffer allocated above (same element count as x here).
      const int64_t numel = x.numel();
      if (numel > 0) {
        constexpr int kThreads = 256;
        const int64_t kMaxBlocks = 65535;
        const int blocks = static_cast<int>(std::min<int64_t>(
            (numel + kThreads - 1) / kThreads, kMaxBlocks));
        detail::TransformCastKernel<Tx, Ty, TransformOp>
            <<<blocks, kThreads, 0, stream>>>(
                x_data, y_data, TransformOp(config.reduce_num), numel);
      }
    }
    return;
  }

  // When config.should_reduce_again is true, the first kernel writes its
  // partial results into `tmp` and a second kernel folds them into y_data.
  // Otherwise the first kernel writes y_data directly. `tmp` must stay alive
  // until both launches are enqueued.
  framework::Tensor tmp;
  config.SetOutputData(y_data, x.place(), &tmp);

  auto reducer = ReduceOp<Tx, Ty>();
  // launch CUB::Reduce
  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, TransformOp(config.reduce_num));
    // First call only queries the required temporary storage size.
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);
    framework::Tensor cub_temp;
    auto* temp_storage = cub_temp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        x.place());
    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);
    return;
  }

  // reduce_num >= 2 here, so GetBlockDim returns a power of two in
  // [2, kMaxThread] and one of the cases below always matches.
#define CUB_BLOCK_DIM_CASE(block_dim)                                \
  case block_dim: {                                                  \
    constexpr auto kBlockDim = block_dim;                            \
    ReduceKernelImpl<Tx, Ty, kBlockDim, ReduceOp<Tx, Ty>>(           \
        x_data, y_data, reducer, reducer.initial(), stream, config); \
  } break

  switch (detail::GetBlockDim(config.reduce_num)) {
    CUB_BLOCK_DIM_CASE(256);
    CUB_BLOCK_DIM_CASE(128);
    CUB_BLOCK_DIM_CASE(64);
    CUB_BLOCK_DIM_CASE(32);
    CUB_BLOCK_DIM_CASE(16);
    CUB_BLOCK_DIM_CASE(8);
    CUB_BLOCK_DIM_CASE(4);
    CUB_BLOCK_DIM_CASE(2);
  }
#undef CUB_BLOCK_DIM_CASE
}

700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715
// Captures the arguments of a reduction so that the output element type Ty
// can be supplied later through apply<Ty>(). Presumably this is driven by a
// data-type visitor at the call site; confirm against the caller.
// Note: `x` is held by reference and must outlive this object.
template <typename Tx, template <typename, typename> class ReduceOp>
struct TensorReduceFunc {
  const framework::Tensor& x;
  framework::Tensor* y;
  std::vector<int> origin_reduce_dims;
  gpuStream_t stream;

  TensorReduceFunc(const framework::Tensor& x_in, framework::Tensor* y_out,
                   std::vector<int> dims, gpuStream_t s)
      : x(x_in), y(y_out), origin_reduce_dims(dims), stream(s) {}

  // Runs the reduction, producing an output with element type Ty.
  template <typename Ty>
  void apply() const {
    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims,
                                              stream);
  }
};

716 717
}  // namespace operators
}  // namespace paddle