// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

// Minimum reduce_num at which ReduceHigherDim splits the reduction across
// several blocks (and therefore needs a second reduction pass).
#define REDUCE_SPLIT_BOUNDARY 512
#define REDUCE_VEC_SIZE 4

namespace paddle {
namespace operators {

namespace kps = paddle::operators::kernel_primitives;

namespace details {

// Returns the largest power of two that is <= n for positive n, and 1 when
// the computed value would be below 1 (e.g. n == 0).
static inline int GetLastPow2(int n) {
  // Propagate the highest set bit into every lower bit position.
  for (int shift = 1; shift <= 16; shift <<= 1) {
    n |= (n >> shift);
  }
  // Dropping all but the top bit leaves the power of two.
  const int top_bit = n - (n >> 1);
  return top_bit < 1 ? 1 : top_bit;
}

// Ceiling division: for a >= 0 and b > 0 returns the number of b-sized chunks
// needed to cover a. Note that, despite the name, the result is a chunk count
// and not `a` rounded up to a multiple of `b`.
static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }

// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
//
// `idx` selects a subset of axes of `dims`. The result has one entry per
// selected axis: the last entry is 1 and entry i is the product of
// dims[idx[j]] for every j > i, i.e. the strides of the sub-shape formed by
// the selected axes. Returns an empty vector when `idx` is empty.
static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
                                             const std::vector<int>& idx) {
  int n = static_cast<int>(idx.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  // Walk from the innermost selected axis outwards, accumulating extents.
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[idx[i + 1]];
  }
  return strides;
}

77 78
// get blockDim for reduceLastDim and reduceAny
static inline int GetBlockDim(int block_dim) {
79 80 81
  return block_dim >= kps::details::kReduceMaxThread
             ? kps::details::kReduceMaxThread
             : GetLastPow2(block_dim);
82 83
}

84 85
// check reduce rand is valid
static inline void CheckReduceRank(int reduce_rank, int rank) {
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
  if (rank % 2 == 0) {
    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                      platform::errors::InvalidArgument(
                          "ReduceOp: invalid reduce rank. When rank = %d, "
                          "reduce_rank must be %d, but got %d.",
                          rank, rank / 2, reduce_rank));
  } else {
    auto lower_rank = (rank - 1) / 2;
    auto upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE_EQ(
        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
        platform::errors::InvalidArgument(
            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
            "must be %d or %d, but got %d.",
            rank, lower_rank, upper_rank, reduce_rank));
  }
}

104
// convert dims from vector to array
105
template <typename T, size_t ElementCount, typename VectorLikeType>
106
static inline paddle::framework::Array<T, ElementCount> VectorToArray(
107
    const VectorLikeType& vec) {
108
  PADDLE_ENFORCE_LE(vec.size(), ElementCount,
109 110
                    platform::errors::InvalidArgument(
                        "Cub reduce Array: size not match. Received "
111
                        "vec.size() %d > ElementCount %d.",
112 113 114
                        vec.size(), ElementCount));
  size_t n = static_cast<size_t>(vec.size());
  paddle::framework::Array<T, ElementCount> ret;
115 116 117
  for (size_t i = 0; i < n; ++i) {
    ret[i] = vec[i];
  }
118 119 120
  return ret;
}

}  // namespace details

using Tensor = framework::Tensor;
constexpr int kMaxRank = framework::DDim::kMaxRank;

// Kernel strategy chosen by ReduceConfig::SetReduceType.
enum ReduceType {
  // every (merged) dim is reduced, or x is 2-D and only its last dim is
  // reduced
  kReduceLastDim = 0x01,
  // exactly one (merged) dim is reduced and the case above does not apply
  // (ReduceFirstDim or reduceSecondDim)
  kReduceHigherDim = 0x02,
  // all remaining cases, i.e. more than one (merged) reduce dim
  kReduceAny = 0x03,
};

132 133 134 135 136
struct IndexCalculator {
  IndexCalculator(int dim, const std::vector<int>& cal_dims,
                  const std::vector<int>& cal_strides,
                  const std::vector<int>& full_strides)
      : dim(dim) {
137 138
    dims = details::VectorToArray<int, kMaxRank>(cal_dims);
    strides = details::VectorToArray<int, kMaxRank>(full_strides);
139
    std::vector<platform::FastDivMod> cal_divmoders;
140 141
    // fast divmod
    for (auto i : cal_strides) {
142
      cal_divmoders.push_back(platform::FastDivMod(i));
143
    }
144
    divmoders =
145
        details::VectorToArray<platform::FastDivMod, kMaxRank>(cal_divmoders);
146 147
  }

148
  __device__ inline int operator()(int offset) const {
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
    int index = 0;
#pragma unroll
    for (int i = 0; i < kMaxRank; ++i) {
      if (i == dim) {
        break;
      }
      auto divmod = divmoders[i].Divmod(offset);
      index += (divmod.val[0] * strides[dims[i]]);
      offset = divmod.val[1];
    }
    return index;
  }

  int dim;
  framework::Array<int, kMaxRank> dims;
  framework::Array<int, kMaxRank> strides;
165
  framework::Array<platform::FastDivMod, kMaxRank> divmoders;
166 167
};

168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250
template <bool ReduceLastDim = false>
struct ReduceIndexMapping {
  const kps::DimConfig dim;
  HOSTDEVICE explicit ReduceIndexMapping(const kps::DimConfig& dims)
      : dim(dims) {}

  __device__ __forceinline__ int BlockIdX() {
#ifdef PADDLE_WITH_XPU2
    if (ReduceLastDim) {
      return (cluster_id() / dim.split_num_x % dim.split_num_y);
    } else {
      return cluster_id() % dim.split_num_x;
    }
#else
    return blockIdx.x;
#endif
  }

  __device__ __forceinline__ int BlockIdY() {
#ifdef PADDLE_WITH_XPU2
    if (ReduceLastDim) {
      return (cluster_id() % dim.split_num_x);
    } else {
      return (cluster_id() / dim.split_num_x % dim.split_num_y);
    }
#else
    return blockIdx.y;
#endif
  }

  __device__ __forceinline__ int BlockDimX() {
#ifdef PADDLE_WITH_XPU2
    return dim.deal_size_x;
#else
    return blockDim.x;
#endif
  }

  __device__ __forceinline__ int BlockDimY() {
#ifdef PADDLE_WITH_XPU2
    return dim.deal_size_y;
#else
    return blockDim.y;
#endif
  }

  __device__ __forceinline__ int GridDimX() {
#ifdef PADDLE_WITH_XPU2
    if (ReduceLastDim) {
      return dim.split_num_y;
    } else {
      return dim.split_num_x;
    }
#else
    return gridDim.x;
#endif
  }

  __device__ __forceinline__ int GridDimY() {
#ifdef PADDLE_WITH_XPU2
    if (ReduceLastDim) {
      return dim.split_num_x;
    } else {
      return dim.split_num_y;
    }
#else
    return gridDim.y;
#endif
  }

  __device__ __forceinline__ int GetLoopSize() {
#ifdef PADDLE_WITH_XPU2
    if (ReduceLastDim) {
      return dim.deal_size_y;
    } else {
      return dim.deal_size_x;
    }
#else
    return 1;
#endif
  }
};

251 252
// when reduce_type == kReduceLastDim this struct will be used
// for higher performance
253 254
struct OneDimIndexCal {
  explicit OneDimIndexCal(int num) : stride(num) {}
255 256 257 258 259

  __device__ inline int operator()(int index) const { return index * stride; }
  int stride;
};

260 261 262
// reduce config
template <typename Ty>
struct ReduceConfig {
263 264 265
  ReduceConfig(const std::vector<int>& origin_reduce_dims,
               const std::vector<int>& origin_x_dim)
      : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}
266 267 268 269 270

  // get the parameters of reduceKernel
  void Run() {
    // step1: update the reduce_dim left_dim and x_dim
    SetReduceDim();
271

272 273
    // step2: get the strides of dim for reduceAny and reduceLastDim
    SetStrides();
274

275 276
    // step3: get the type of reduce
    SetReduceType();
277

278 279 280 281 282 283
    // step4: set the block and grid for launch kernel
    SetBlockDim();
  }

  // when should_reduce_again is true, we need malloc temp space for temp data
  void SetOutputData(Ty* y_data, const platform::Place& place,
284
                     framework::Tensor* tmp) {
285
    if (should_reduce_again) {
286
      output_data = tmp->mutable_data<Ty>(
287
          framework::make_ddim(
288
              {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}),
289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
          place);
    } else {
      output_data = y_data;
    }
  }

 private:
  // set reduce_dim, left_dim and update x_dim
  // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1]
  //     --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
  void SetReduceDim() {
    std::set<int> reduce_set;
    for (auto e : reduce_dims_origin) {
      auto pos = e >= 0 ? e : e + x_dim.size();
      reduce_set.insert(pos);
    }
305

306 307
    std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
    std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());
308 309 310 311 312 313 314 315 316 317

    // update reduce_dim and x_dim
    std::vector<int> x_new_dim;

    reduce_dim.push_back(reduce_dim_temp[0]);
    x_new_dim.push_back(x_dim[0]);

    int idx_reduce = 1;
    int num = 0;

318
    if (reduce_dim_temp.size() > 1) {
319 320 321 322 323 324 325 326 327 328 329 330 331 332
      for (int i = 1; i < x_dim.size(); i++) {
        if ((idx_reduce < reduce_dim_temp.size()) &&
            (i == reduce_dim_temp[idx_reduce])) {
          int result =
              reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
          bool is_equal = ((result - num) == 1);
          if (is_equal) {
            x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
            num++;
          } else {
            reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
            x_new_dim.push_back(x_dim[i]);
          }
          idx_reduce++;
333
        } else {
334
          x_new_dim.push_back(x_dim[i]);
335 336 337
        }
      }
    } else {
338
      x_new_dim = x_dim;
339 340
    }

341 342 343 344 345
    // update x_dim
    x_dim = x_new_dim;
    std::vector<int>().swap(x_new_dim);

    std::vector<int> reduce_dim_new;
346 347 348 349 350
    int is_reduced = 0;
    for (auto e : reduce_dim) {
      is_reduced |= 1 << e;
    }

351 352
    std::vector<int>().swap(reduce_dim);

353 354
    for (int i = 0; i < x_dim.size(); i++) {
      if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
355
        x_new_dim.push_back(x_dim[i]);
356
        if ((is_reduced >> i) & 1)
357
          reduce_dim_new.push_back(x_new_dim.size() - 1);
358
      } else {
359
        x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
360 361 362
      }
    }

363 364
    x_dim = x_new_dim;
    reduce_dim = reduce_dim_new;
365 366 367 368 369 370 371 372 373 374 375 376 377

    int x_rank = static_cast<int>(x_dim.size());
    std::set<int> left_set;

    for (int i = 0; i < x_rank; ++i) {
      left_set.insert(i);
    }

    for (auto e : reduce_dim) {
      left_set.erase(e);
    }

    left_dim.assign(left_set.begin(), left_set.end());
378 379

    // if the last dim gets involved in reduction
380
    reduce_last_dim = (reduce_dim.back() == x_dim.size() - 1);
381 382 383 384 385 386 387 388 389 390 391 392
  }

  // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
  // eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
  //     --SetStrides--> x_strides= [6,1], reduce_strides = [1],
  //     left_strides = [1]
  void SetStrides() {
    std::vector<int> idx_dim;
    for (int i = 0; i < x_dim.size(); i++) {
      idx_dim.push_back(i);
    }

393 394 395
    x_strides = details::GetDimStrides(x_dim, idx_dim);
    reduce_strides = details::GetDimStrides(x_dim, reduce_dim);
    left_strides = details::GetDimStrides(x_dim, left_dim);
396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411
    reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];

    left_num = 1;
    if (left_dim.size()) {
      left_num = left_strides[0] * x_dim[left_dim[0]];
    }
  }

  // get the reduceType
  // eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim -->reduceFirstDim
  //     x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim
  //     x_dim = [8] reduce_dim = [0] --> reduceAll
  //     x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny
  void SetReduceType() {
    int rank = x_dim.size();
    int reduce_rank = reduce_dim.size();
412 413 414
    bool is_last_dim =
        (rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1);
    if (rank == reduce_rank || is_last_dim) {
415
      reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
416
    } else if (reduce_rank == 1) {
417 418 419 420 421 422 423 424
// ReduceFirstDim and reduceSecondDim
#ifdef PADDLE_WITH_XPU2
      if (reduce_dim[0] == 0) {
        reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
      } else {
        reduce_type = static_cast<int>(ReduceType::kReduceAny);
      }
#else
425
      reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
426
#endif
427 428 429 430 431
    } else {
      reduce_type = static_cast<int>(ReduceType::kReduceAny);
    }
  }

432 433 434
  void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
    constexpr int min_reduce_num_per_thread = 16;
    constexpr int max_reduce_num_per_thread = 256;
435
    constexpr int max_num_threads = kps::details::kReduceMaxThread;
436 437

    // set block size.
438
    // 1. If reduce_last_dim == true, all the threads whose threadIdx.y are same
439 440
    //    will process the reduction for one output.
    //    The number of output for one block is blockDim.y;
441
    // 2. If reduce_last_dim == false, different threadIdx.x will process
442 443 444 445
    //    different reduction and gets the output separately. If it is
    //    necessary, it should reduce in block y.
    //    The number of output for one block is blockDim.x;
    int block_x, block_y;
446
    int grid_num, reduce_num_per_thread;
447 448 449
    if (reduce_last_dim) {
      block_x = details::GetBlockDim(reduce_num);
      block_y = details::GetBlockDim(left_num);
450 451 452
      block_dim->x = block_x;
      block_dim->y =
          std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
453 454
      grid_num = details::AlignUp(left_num, block_dim->y);
      reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->x);
455
    } else {
456 457
      block_x = details::GetBlockDim(left_num);
      block_y = details::GetBlockDim(reduce_num);
458 459 460 461 462
      block_dim->x = std::min(block_x, 32);
      block_dim->y =
          std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
      block_dim->x =
          std::min(block_x, static_cast<int>(max_num_threads / block_dim->y));
463 464
      grid_num = details::AlignUp(left_num, block_dim->x);
      reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->y);
465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
    }
    int device_id = platform::GetCurrentDeviceId();
    int max_mp = platform::GetCUDAMultiProcessors(device_id);
    int max_threads_per_mp =
        platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
    int max_threads = max_threads_per_mp * max_mp;
    int num_threads = block_dim->x * block_dim->y;
    int max_num_blocks = max_threads / num_threads;

    // set grid size.
    // Whether to set grid.y larger than 1, there are 3 following rules:
    // 1. The number that each thread process should no less than
    //    min_reduce_num_per_threadbut no more than max_reduce_num_per_thread;
    // 2. It should maximize the utilization of SM.
    // So we choose the minimum between input_split_num_1 and input_split_num_3
    // to make each thread process as mush data as possible. Meanwhile,
    // the number cannot be larger than max_reduce_num_per_thread, so we
    // choose the maximum between the result above and input_split_num_2.
    int input_split_num_1 =
484
        details::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
485
    int input_split_num_2 =
486 487
        details::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
    int input_split_num_3 = details::AlignUp(max_num_blocks, grid_num);
488 489 490 491 492 493 494 495 496 497

    grid_dim->x = grid_num;
    grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
                           input_split_num_2);
    // if grid.y > 1, we need launch reduce kernel again.
    if (grid_dim->y > 1) {
      should_reduce_again = true;
    }
  }

498 499 500 501
  // set block and grid for launch kernel
  // for ReduceHigherDim: if block is enough -> splite reduce_num
  //                     else init block(32, 1) grid(block_num, 1)
  // for others: block(block_num, 1) , grid(left_num, 1)
502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530
  void SetBlockDimForHigher(dim3* block_dim, dim3* grid_dim) {
    int last_dim_num = x_dim.back();
    // update left_num
    int grid_z = left_num / last_dim_num;
    left_num = last_dim_num;
    grid_dim->z = grid_z;
    int device_id = platform::GetCurrentDeviceId();
    int max_mp = platform::GetCUDAMultiProcessors(device_id);
    int max_threads_per_mp =
        platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
    int max_threads = max_threads_per_mp * max_mp;
    // init
    int num_block = (max_threads / left_num);
    block_dim->x = details::GetBlockDim(left_num);
    grid_dim->x = details::AlignUp(left_num, block_dim->x);
    blocking_size = reduce_num;

    if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
      blocking_size = details::GetLastPow2(reduce_num / num_block);
      if (blocking_size <= 1) {
        blocking_size = details::GetLastPow2(sqrt(reduce_num));
      } else if (blocking_size * 2 < reduce_num) {
        blocking_size *= 2;
      }
      should_reduce_again = true;
      grid_dim->y = details::AlignUp(reduce_num, blocking_size);
    }
  }

531 532
  void SetBlockDim() {
    // init
533
    int block_num = details::GetBlockDim(reduce_num);
534
    should_reduce_again = false;
535 536
    dim3 block_dim(block_num, 1, 1);
    dim3 grid_dim(left_num, 1, 1);
537
    blocking_size = reduce_num;
538 539 540 541 542 543 544 545 546 547 548 549 550
#ifdef PADDLE_WITH_XPU2
    if (reduce_last_dim) {
      block_dim.x = 128;
      block_dim.y = reduce_num;
      grid_dim.x = 8;
      grid_dim.y = 1;
    } else {
      block_dim.x = 128;
      block_dim.y = left_num;
      grid_dim.x = 8;
      grid_dim.y = 1;
    }
#else
551
    if (reduce_type == ReduceType::kReduceHigherDim) {
552
      SetBlockDimForHigher(&block_dim, &grid_dim);
553
    } else {
554
      SetBlockDimForReduceAny(&block_dim, &grid_dim);
555
    }
556
#endif
557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575

    block = block_dim;
    grid = grid_dim;
  }

 public:
  std::vector<int> reduce_dims_origin;
  std::vector<int> reduce_dim;
  std::vector<int> x_dim;
  std::vector<int> left_dim;
  std::vector<int> x_strides;
  std::vector<int> left_strides;
  std::vector<int> reduce_strides;

  int reduce_type;
  int reduce_num;
  int left_num;
  int blocking_size;
  bool should_reduce_again;
576
  bool reduce_last_dim;
577 578 579 580 581 582

  Ty* output_data;

  dim3 block;
  dim3 grid;
};
583

584
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
585 586
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
587
template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
588 589 590 591 592 593
          typename TransformOp, typename Calculator>
__global__ void ReduceAnyKernel(const Tx* x, Ty* y, ReduceOp reducer,
                                TransformOp transformer, MPType init,
                                int reduce_num, int left_num,
                                bool reduce_last_dim,
                                const Calculator reduce_index_calculator,
594 595
                                const Calculator left_index_calculator,
                                const kps::DimConfig dim) {
596
  int input_idx, left_idx, stride;
597 598
  int block_size = 0;
  bool need_store = true;
599
  int loop_left = 0;
600
  int tid = 0;
601
  // the last dim gets involved in reduction
602 603
  int store_offset = 0;
  int stride_left = 0;
604
  if (reduce_last_dim) {
605 606 607 608 609 610 611 612 613
    auto block = ReduceIndexMapping<true>(dim);
    input_idx = block.BlockIdY() * block.BlockDimX();
    left_idx = block.BlockIdX() * block.BlockDimY() + THREAD_ID_Y;
    stride = block.GridDimY() * block.BlockDimX();
    block_size = block.BlockDimX();
    need_store = (THREAD_ID_X == 0) && (left_idx < left_num);
    store_offset = block.BlockIdY() * left_num + left_idx;
    loop_left = min(block.GetLoopSize(), left_num - left_idx);
    stride_left = 1;
614
    tid = threadIdx.x;
615
  } else {
616 617 618 619 620 621 622 623 624
    auto block = ReduceIndexMapping<false>(dim);
    input_idx = block.BlockIdY() * block.BlockDimY();
    left_idx = block.BlockIdX() * block.BlockDimX() + THREAD_ID_X;
    stride = block.GridDimY() * block.BlockDimY();
    block_size = block.BlockDimY();
    need_store = (THREAD_ID_Y == 0) && (left_idx < left_num);
    loop_left = min(block.GetLoopSize(), left_num - left_idx);
    stride_left = block.BlockDimX() * block.GridDimX();
    store_offset = block.BlockIdY() * left_num + left_idx;
625
    tid = threadIdx.y;
626
  }
627 628
  // calculate the offset, means the addr where each thread really start.
  // 1. reduce for each thread
629 630 631 632 633 634
  MPType input_compute[REDUCE_VEC_SIZE];
  Tx input_reg[REDUCE_VEC_SIZE];
  for (int i = 0; i < loop_left; i += stride_left) {
    int input_offset = left_index_calculator(left_idx + i);
    const Tx* input = x + input_offset;
    MPType reduce_var = init;
635 636
    // load REDUCE_VEC_SIZE data once, and then compute
    int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
637 638
    for (; input_idx + block_size < bound;
         input_idx += REDUCE_VEC_SIZE * stride) {
639 640 641 642 643 644 645 646 647
      kps::ReadDataReduce<Tx, Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator,
                          kps::IdentityFunctor<Tx>, false>(
          &input_reg[0], input, input_idx, reduce_index_calculator, 1,
          reduce_num, 1, stride, kps::IdentityFunctor<Tx>(), reduce_last_dim);
      kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
          &input_compute[0], &input_reg[0], transformer);
      kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
                  kps::details::ReduceMode::kLocalMode>(
          &reduce_var, &input_compute[0], reducer, reduce_last_dim);
648
    }
649

650 651 652 653 654 655 656 657 658 659 660 661 662 663
    kps::Init<MPType, REDUCE_VEC_SIZE>(&input_compute[0], init);
    kps::ReadDataReduce<Tx, MPType, 1, REDUCE_VEC_SIZE, 1, 1, Calculator,
                        TransformOp, true>(
        &input_compute[0], input, input_idx, reduce_index_calculator, 1,
        reduce_num - input_idx, 1, stride, transformer, reduce_last_dim);
    kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
                kps::details::ReduceMode::kLocalMode>(
        &reduce_var, &input_compute[0], reducer, reduce_last_dim);

    kps::Reduce<MPType, 1, 1, 1, ReduceOp, kps::details::kGlobalMode>(
        &reduce_var, &reduce_var, reducer, reduce_last_dim);
    if (need_store) {
      y[store_offset + i] = static_cast<Ty>(reduce_var);
    }
664 665 666
  }
}

667 668
template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
          typename TransformOp>
669 670 671
__global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer,
                                      TransformOp transformer, MPType init,
                                      int reduce_num, int left_num,
672 673
                                      int blocking_size,
                                      const kps::DimConfig dim) {
674 675
  // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
  // function will be used
676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720
  auto block = ReduceIndexMapping<false>(dim);
  int idy = block.BlockIdY() * blocking_size;
  int idx = block.BlockIdX() * block.BlockDimX();
  int idz = BLOCK_ID_Z * left_num;
  int stride = dim.split_num_x * dim.deal_size_x;
  int size = left_num - dim.rem_x;
  int loop_size = min(reduce_num - idy, blocking_size);
  int store_offset = block.BlockIdY() * left_num + idz * block.GridDimY();
  int block_offset = idy * left_num + idz * reduce_num;
  const Tx* input = x + block_offset;
  Tx reduce_input;
  for (; idx < size; idx += stride) {
    MPType reduce_var = init;
    MPType reduce_compute = init;
    for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) {
      kps::ReadData<Tx, Tx, 1, 1, 1, false>(&reduce_input,
                                            input + loop_idx * left_num + idx,
                                            block.BlockDimX(), 1, 1, left_num);
      kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
          &reduce_compute, &reduce_input, transformer);
      kps::Reduce<MPType, 1, 1, 1, ReduceOp,
                  kps::details::ReduceMode::kLocalMode>(
          &reduce_var, &reduce_compute, reducer, false);
    }
    Ty result = static_cast<Ty>(reduce_var);
    kps::WriteData<Ty, 1, 1, 1, false>(y + store_offset + idx, &result,
                                       block.BlockDimX());
  }

  if (idx < left_num) {
    MPType reduce_var = init;
    MPType reduce_compute = init;
    for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) {
      kps::ReadData<Tx, Tx, 1, 1, 1, true>(&reduce_input,
                                           input + loop_idx * left_num + idx,
                                           dim.rem_x, 1, 1, left_num);
      kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
          &reduce_compute, &reduce_input, transformer);
      kps::Reduce<MPType, 1, 1, 1, ReduceOp,
                  kps::details::ReduceMode::kLocalMode>(
          &reduce_var, &reduce_compute, reducer, false);
    }
    Ty result = static_cast<Ty>(reduce_var);
    kps::WriteData<Ty, 1, 1, 1, true>(y + store_offset + idx, &result,
                                      dim.rem_x);
721 722 723
  }
}

724
template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
725
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
726
                               const ReduceOp& reducer, MPType init,
727 728
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  using TransformOp = typename ReduceOp::Transformer;
729 730 731 732 733

  if (config.reduce_type == kReduceLastDim) {
    int stride_reduce = 1;
    int stride_left = config.reduce_num;
    // for higher performance
734 735
    auto reduce_index_calculator = OneDimIndexCal(stride_reduce);
    auto left_index_calculator = OneDimIndexCal(stride_left);
736

737 738 739 740 741 742 743 744 745 746 747 748
    kps::DimConfig dim =
        kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
                       config.block.x, config.block.y, 0);
    dim.SetRem(config.reduce_num % config.block.x, 0, 0);

#ifdef PADDLE_WITH_XPU2
    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                    OneDimIndexCal><<<8, 128, stream>>>(
        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
        init, config.reduce_num, config.left_num, config.reduce_last_dim,
        reduce_index_calculator, left_index_calculator, dim);
#else
749
    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
750
                    OneDimIndexCal><<<config.grid, config.block, 0, stream>>>(
751 752
        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
        init, config.reduce_num, config.left_num, config.reduce_last_dim,
753 754
        reduce_index_calculator, left_index_calculator, dim);
#endif
755 756 757 758 759 760 761 762 763

  } else {
    int reduce_rank = config.reduce_strides.size();
    int left_rank = config.left_strides.size();
    auto reduce_index_calculator =
        IndexCalculator(reduce_rank, config.reduce_dim, config.reduce_strides,
                        config.x_strides);
    auto left_index_calculator = IndexCalculator(
        left_rank, config.left_dim, config.left_strides, config.x_strides);
764 765 766 767 768 769 770 771 772 773 774 775 776

    kps::DimConfig dim =
        kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
                       config.block.x, config.block.y, 0);
    dim.SetRem(config.reduce_num % config.block.x, 0, 0);

#ifdef PADDLE_WITH_XPU2
    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                    IndexCalculator><<<8, 128, stream>>>(
        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
        init, config.reduce_num, config.left_num, config.reduce_last_dim,
        reduce_index_calculator, left_index_calculator, dim);
#else
777 778 779 780
    ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                    IndexCalculator><<<config.grid, config.block, 0, stream>>>(
        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
        init, config.reduce_num, config.left_num, config.reduce_last_dim,
781 782
        reduce_index_calculator, left_index_calculator, dim);
#endif
783
  }
784 785

  if (config.should_reduce_again) {
786 787
    dim3 block;
    dim3 grid;
788
    if (config.reduce_last_dim) {
789
      block = dim3(32, 1, 1);
790
      grid = dim3(details::AlignUp(config.left_num, 32), 1, 1);
791 792 793 794
    } else {
      block = dim3(config.block.x, 1, 1);
      grid = dim3(config.grid.x, 1, config.grid.z);
    }
795

796 797 798 799 800 801 802 803 804 805 806 807
    auto last_index = OneDimIndexCal(1);
    auto first_index = OneDimIndexCal(config.left_num);
    kps::DimConfig dim =
        kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
    dim.SetRem(config.left_num % block.x, 0, 0);
#ifdef PADDLE_WITH_XPU2
    ReduceHigherDimKernel<Ty, Ty, MPType, ReduceOp,
                          kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
        config.output_data, y_data, reducer,
        kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, dim);
#else
808
    ReduceHigherDimKernel<
809
        Ty, Ty, MPType, ReduceOp,
810
        kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
811
        config.output_data, y_data, reducer,
812 813 814
        kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, dim);
#endif
815 816 817
  }
}

// Reduces tensor `x` over `origin_reduce_dims` and writes the result to `y`.
//
// Template parameters:
//   Tx       - element type read from `x`.
//   Ty       - element type written to `y`.
//   ReduceOp - binary-templated reducer; it is instantiated as
//              ReduceOp<Tx, Ty> for the CUB path and ReduceOp<Tx, MPType> for
//              the hand-written kernels, and must expose `Transformer` and
//              `initial()`.
//
// Parameters:
//   x                  - input tensor.
//   y                  - output tensor; its buffer is obtained with
//                        mutable_data<Ty>() on the place of `x`.
//   origin_reduce_dims - dimensions handed to ReduceConfig.
//   stream             - stream on which every kernel below is launched.
//
// Dispatch order:
//   1. config.reduce_num == 1: nothing to reduce, so `x` is copied (same
//      type) or cast (different type) into `y`.
//   2. reduce_num equals the element count of `x` and Tx is not float16:
//      cub::DeviceReduce::Reduce.
//   3. config.reduce_type == kReduceHigherDim: ReduceHigherDimKernel, plus a
//      second pass when config.should_reduce_again is set.
//   4. Everything else: LaunchReduceKernel.
template <typename Tx, typename Ty,
          template <typename, typename> class ReduceOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                             std::vector<int> origin_reduce_dims,
                             gpuStream_t stream) {
  auto x_dim = framework::vectorize<int>(x.dims());
  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
  config.Run();
  // Keep the type returned by numel() instead of narrowing it to int, so the
  // "reduce everything" comparison below is not done on a truncated value.
  const auto numel = x.numel();

  // `tmp` is handed to config.SetOutputData() below. Per the original note it
  // holds the intermediate result when should_reduce_again is true; otherwise
  // the result goes straight to y_data. It is declared at function scope so it
  // stays alive for every launch issued from this function.
  framework::Tensor tmp;
  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());

  // Case 1: a reduction over a single element is a plain copy or cast.
  if (config.reduce_num == 1) {
    auto out_dims = y->dims();
    if (x.type() == y->type()) {
      framework::TensorCopy(x, y->place(), y);
      // Re-apply the dims captured before the copy.
      y->Resize(out_dims);
    } else {
      auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
          paddle::platform::DeviceContextPool::Instance().Get(x.place()));
      framework::VisitDataType(
          static_cast<framework::proto::VarType::Type>(y->type()),
          CastOpFunctor<platform::CUDADeviceContext, Tx>(&x, y, *dev_ctx));
    }
    return;
  }

  config.SetOutputData(y_data, x.place(), &tmp);

  // Case 2: full reduction of a non-float16 input goes through CUB.
  bool use_cub_reduce = (config.reduce_num == numel) &&
                        (!std::is_same<Tx, paddle::platform::float16>::value);
  if (use_cub_reduce) {
    using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
    auto reducer = ReduceOp<Tx, Ty>();
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, TransformOp(config.reduce_num));

    // First call with a null workspace only queries the required size.
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);

    // Dedicated workspace tensor; named differently from the function-scope
    // `tmp` so the two buffers cannot be confused.
    framework::Tensor cub_workspace;
    auto* temp_storage = cub_workspace.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        x.place());

    // Second call performs the reduction.
    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                              config.reduce_num, reducer, reducer.initial(),
                              stream);
    return;
  }

  using MPType = typename details::MPTypeTrait<Ty>::Type;
  auto reducer = ReduceOp<Tx, MPType>();

  // Case 3: launch ReduceHigherDimKernel.
  // Used when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1.
  // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
  //     if axis = 1 then grid.z = nz, grid.y = ny / block_size,
  //                      grid.x = nx / 32
  //     else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
  if (config.reduce_type == ReduceType::kReduceHigherDim) {
    using TransformOp = typename ReduceOp<Tx, MPType>::Transformer;
    kps::DimConfig dim =
        kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
                       config.block.x, config.blocking_size, 0);
    dim.SetRem(config.left_num % config.block.x,
               config.reduce_num % config.blocking_size, 0);

    // First pass: reads x_data, writes config.output_data.
#ifdef PADDLE_WITH_XPU2
    // XPU2 build uses a fixed <<<8, 128, stream>>> launch configuration.
    ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>,
                          TransformOp><<<8, 128, stream>>>(
        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
        reducer.initial(), config.reduce_num, config.left_num,
        config.blocking_size, dim);
#else
    ReduceHigherDimKernel<
        Tx, Ty, MPType, ReduceOp<Tx, MPType>,
        TransformOp><<<config.grid, config.block, 0, stream>>>(
        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
        reducer.initial(), config.reduce_num, config.left_num,
        config.blocking_size, dim);
#endif

    // Second pass: reads config.output_data and reduces the config.grid.y
    // partial results of the first pass into y_data.
    if (config.should_reduce_again) {
      dim3 block = dim3(config.block.x, 1, 1);
      dim3 grid = dim3(config.grid.x, 1, config.grid.z);
      kps::DimConfig dim2 =
          kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
      dim2.SetRem(config.left_num % config.block.x, 0, 0);

#ifdef PADDLE_WITH_XPU2
      ReduceHigherDimKernel<
          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
          kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
          config.output_data, y_data, reducer,
          kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
          config.grid.y, config.left_num, config.grid.y, dim2);
#else
      ReduceHigherDimKernel<
          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
          kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
          config.output_data, y_data, reducer,
          kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
          config.grid.y, config.left_num, config.grid.y, dim2);
#endif
    }
    return;
  }

  // Case 4: generic path.
  // Used when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
  // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size().
  LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
      x_data, y_data, reducer, reducer.initial(), stream, config);
}

template <typename Tx, template <typename, typename> class ReduceOp>
struct TensorReduceFunc {
  const framework::Tensor& x;
  framework::Tensor* y;
  std::vector<int> origin_reduce_dims;
  gpuStream_t stream;
  TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
                   std::vector<int> origin_reduce_dims, gpuStream_t stream)
      : x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}

  template <typename Ty>
  void apply() const {
    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
  }
};

}  // namespace operators
}  // namespace paddle