// softmax_cudnn_op.cu.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>
#include <limits>

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

namespace paddle {
namespace operators {

// Local short names for the framework/platform types used in this file.
typedef platform::ScopedTensorDescriptor ScopedTensorDescriptor;
typedef platform::DataLayout DataLayout;
typedef framework::Tensor Tensor;

// Vectorization trait: VecT4<T>::Type is a built-in vector type whose size is
// exactly 4 * sizeof(T), so four consecutive T values can be moved with a
// single load/store. Only double, float and platform::float16 are
// specialized; any other T has no nested `Type`.
template <typename T>
class VecT4 {};
template <>
class VecT4<double> {
 public:
  // longlong4 rather than long4: `long` is only 4 bytes on LLP64 platforms
  // (e.g. Windows), where long4 would be 16 bytes instead of 32.
  using Type = longlong4;
};
template <>
class VecT4<float> {
 public:
  using Type = int4;
};
template <>
class VecT4<platform::float16> {
 public:
  using Type = int2;
};

// Enforce the trait's size contract at compile time on every platform.
static_assert(sizeof(VecT4<double>::Type) == 4 * sizeof(double),
              "VecT4<double>::Type must hold exactly 4 doubles");
static_assert(sizeof(VecT4<float>::Type) == 4 * sizeof(float),
              "VecT4<float>::Type must hold exactly 4 floats");
static_assert(sizeof(VecT4<platform::float16>::Type) ==
                  4 * sizeof(platform::float16),
              "VecT4<float16>::Type must hold exactly 4 float16 values");

// Vectorization trait: VecT2<T>::Type is a built-in type whose size is
// 2 * sizeof(T), letting two consecutive T values be moved in one access.
// Specialized for double, float and platform::float16 only.
template <typename T>
class VecT2 {};

template <>
class VecT2<double> {
 public:
  typedef int4 Type;  // 16 bytes == 2 doubles
};

template <>
class VecT2<float> {
 public:
  typedef int2 Type;  // 8 bytes == 2 floats
};

template <>
class VecT2<platform::float16> {
 public:
  typedef int Type;  // 4 bytes == 2 float16 values
};

// Returns the smallest non-negative k with (1 << k) >= value, i.e. the
// exponent of the next power of two not below `value`. Yields 0 for value <= 1.
static inline int log2_ceil(int value) {
  int exponent = 0;
  for (; (1 << exponent) < value; ++exponent) {
  }
  return exponent;
}

// XOR-shuffle sum across the WarpSize lanes of a (sub-)warp, done separately
// for each of the BatchSize entries of `vals`. After the loop every lane holds
// the reduced sum for each entry.
// The shuffle mask is 0xFFFFFFFF, so all 32 lanes of the hardware warp must
// execute this call.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* vals) {
#pragma unroll
  for (int delta = WarpSize >> 1; delta >= 1; delta >>= 1) {
#pragma unroll
    for (int b = 0; b < BatchSize; ++b) {
      const T partner =
          platform::CudaShuffleXorSync(0xFFFFFFFF, vals[b], delta);
      vals[b] = vals[b] + partner;
    }
  }
}

// XOR-shuffle max across the WarpSize lanes of a (sub-)warp, done separately
// for each of the BatchSize entries of `vals`. After the loop every lane holds
// the maximum for each entry.
// The shuffle mask is 0xFFFFFFFF, so all 32 lanes of the hardware warp must
// execute this call.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceMax(T* vals) {
#pragma unroll
  for (int delta = WarpSize >> 1; delta >= 1; delta >>= 1) {
#pragma unroll
    for (int b = 0; b < BatchSize; ++b) {
      const T partner =
          platform::CudaShuffleXorSync(0xFFFFFFFF, vals[b], delta);
      vals[b] = max(vals[b], partner);
    }
  }
}

// Short alias for the kernel-primitive helpers (ReadData, WriteData, Init,
// Reduce, ElementwiseUnary/Binary) used by the warp softmax kernels.
namespace kps = ::paddle::operators::kernel_primitives;

// Binary max reducer for kps::Reduce. initial() returns -infinity, the
// identity element of max.
template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor {
  inline Ty initial() {
    const Ty lowest = -std::numeric_limits<Ty>::infinity();
    return lowest;
  }

  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
    const Ty larger = max(a, b);
    return larger;
  }
};

// Elementwise functor: x -> exp(x - y) converted to Ty, with a fixed shift y
// (0 when default-constructed).
template <typename Tx, typename Ty = Tx>
struct ExpSubFunctor {
  HOSTDEVICE inline ExpSubFunctor() : y(static_cast<Tx>(0.0f)) {}

  HOSTDEVICE explicit inline ExpSubFunctor(Tx shift) : y(shift) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto shifted = x - y;
    return static_cast<Ty>(std::exp(shifted));
  }

 private:
  Tx y;  // value subtracted before exponentiation
};

// Elementwise functor: x -> exp(x) * y converted to Ty, with a fixed factor y
// (1 when default-constructed).
template <typename Tx, typename Ty = Tx>
struct ExpMulFunctor {
  HOSTDEVICE inline ExpMulFunctor() : y(static_cast<Tx>(1.0f)) {}

  HOSTDEVICE explicit inline ExpMulFunctor(Tx factor) : y(factor) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto exp_x = std::exp(x);
    return static_cast<Ty>(exp_x * y);
  }

 private:
  Tx y;  // multiplier applied after exponentiation
};

// Elementwise functor: x -> (x - y) converted to Ty, with a fixed subtrahend y
// (0 when default-constructed).
template <typename Tx, typename Ty = Tx>
struct UnarySubFunctor {
  HOSTDEVICE inline UnarySubFunctor() : y(static_cast<Tx>(0.0f)) {}

  HOSTDEVICE explicit inline UnarySubFunctor(Tx subtrahend) : y(subtrahend) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto diff = x - y;
    return static_cast<Ty>(diff);
  }

 private:
  Tx y;
};

// Elementwise functor: x -> log(x) converted to Ty. The int constructor
// argument is accepted for interface uniformity and ignored.
template <typename Tx, typename Ty = Tx>
struct UnaryLogFunctor {
  HOSTDEVICE inline UnaryLogFunctor() {}

  HOSTDEVICE explicit inline UnaryLogFunctor(int /*unused*/) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto log_x = std::log(x);
    return static_cast<Ty>(log_x);
  }
};

// Elementwise type conversion Tx -> Ty that maps -infinity of Tx explicitly to
// -infinity of Ty; every other value goes through static_cast.
template <typename Tx, typename Ty>
struct DataTransFunctor {
  HOSTDEVICE inline DataTransFunctor() {}

  HOSTDEVICE explicit inline DataTransFunctor(int /*unused*/) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    if (x == -std::numeric_limits<Tx>::infinity()) {
      return -std::numeric_limits<Ty>::infinity();
    }
    return static_cast<Ty>(x);
  }
};

// Elementwise functor: x -> x / n converted to Ty, implemented as a
// multiplication by the reciprocal 1/n computed once at construction
// (n = 1 when default-constructed).
template <typename Tx, typename Ty = Tx>
struct UnaryDivFunctor {
  HOSTDEVICE inline UnaryDivFunctor() : n_inv(static_cast<Tx>(1.0f)) {}

  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n)
      : n_inv(static_cast<Tx>(1.0 / n)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto scaled = x * n_inv;
    return static_cast<Ty>(scaled);
  }

 private:
  Tx n_inv;  // reciprocal of the divisor
};

/*
Core function of computing softmax forward for axis=-1.
The computation includes
  - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j}
  - Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i}) }
  - Compute the output:
      LogMode == false (softmax):     exp(src_{i,j} - maxvalue_{i}) / s_{i}
      LogMode == true  (log-softmax): src_{i,j} - maxvalue_{i} - log(s_{i})
One warp of kWarpSize (<= 32) threads computes 1 or 2 batches (kBatchSize).
For reduction max (sum), each thread first reduces the values it holds, then
the per-thread partials are combined across the warp with XOR shuffles.
Launch layout: blockDim = (kWarpSize, warps per block); each threadIdx.y row
handles kBatchSize consecutive batches. Rows at or beyond batch_size get a
read/write limit of 0, and columns are bounded by element_count through the
boundary-checked kps::ReadData / kps::WriteData.
Assumes element_count (== stride here) is a multiple of kVSize so vectorized
VecT accesses stay aligned; the drivers pick VecT accordingly.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax, const T* src,
                                   const int batch_size, const int stride,
                                   const int element_count) {
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kLoops = kDimCeil / kWarpSize;
  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
  constexpr int kVItem = kLoopsV * kVSize;
  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
  using kMode = kps::details::ReduceMode;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;

  // max index to read (in units of VecT); 0 disables rows past batch_size
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // read data from global memory; slots that are not read stay at -inf so
  // they contribute nothing to the max and exp(-inf) == 0 to the sum
  AccT srcdata[kBatchSize][kLoopsV][kVSize];
  kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
  T src_tmp[kBatchSize][kLoopsV][kVSize];
  kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity());
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // 64-bit row offset: (row * stride) can exceed INT_MAX for large inputs.
    const int64_t ptr = static_cast<int64_t>(first_batch + i) * stride;
    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
    VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
        &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
    kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
        &srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor<T, AccT>());
  }

  // compute max
  AccT max[kBatchSize];
  kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
  kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>,
              kMode::kLocalMode>(&max[0], &srcdata[0][0][0],
                                 ReduceMaxFunctor<AccT>(), true);
  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);

  // compute sum of exp(src - max)
  AccT sum[kBatchSize] = {0};
  if (LogMode) {
    // The output needs the shifted values (src - max), so keep them in
    // srcdata and accumulate the exponentials from a separate buffer.
#pragma unroll
    for (int i = 0; i < kBatchSize; ++i) {
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
          &srcdata[i][0][0], &srcdata[i][0][0], UnarySubFunctor<AccT>(max[i]));
    }
    AccT exp_data[kBatchSize][kLoopsV][kVSize];
    kps::ElementwiseUnary<AccT, AccT, kStep, 1, 1, ExpSubFunctor<AccT>>(
        &exp_data[0][0][0], &srcdata[0][0][0], ExpSubFunctor<AccT>());
    kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
                kMode::kLocalMode>(&sum[0], &exp_data[0][0][0],
                                   kps::AddFunctor<AccT>(), true);
  } else {
    // exp(src - max) is the numerator of the output, so overwrite in place.
#pragma unroll
    for (int i = 0; i < kBatchSize; ++i) {
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
          &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
    }
    kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
                kMode::kLocalMode>(&sum[0], &srcdata[0][0][0],
                                   kps::AddFunctor<AccT>(), true);
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  if (LogMode) {
    // log(s_i): the log-softmax output is (src - max) - log(s_i).
    kps::ElementwiseUnary<AccT, AccT, kBatchSize, 1, 1, UnaryLogFunctor<AccT>>(
        &sum[0], &sum[0], UnaryLogFunctor<AccT>());
  }

  // write result to global memory
  T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (LogMode) {
      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>(
          &out_tmp[i][0][0], &srcdata[i][0][0], UnarySubFunctor<AccT>(sum[i]));
    } else {
      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
          &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
    }
    const int64_t softmax_ptr = static_cast<int64_t>(first_batch + i) * stride;
    VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[softmax_ptr]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
}

/*
Core function of computing softmax backward for axis=-1.
`src` is the forward output y and `grad` is dy (see the backward driver).
The computation includes
  LogMode == false (softmax):
    - Compute sum: s_{i} = sum_{j} src_{i,j} * grad_{i,j}
    - Compute: dst_{i,j} = src_{i,j} * (grad_{i,j} - s_{i})
  LogMode == true (src holds log-softmax values):
    - Compute sum: s_{i} = sum_{j} grad_{i,j}
    - Compute: dst_{i,j} = grad_{i,j} - exp(src_{i,j}) * s_{i}
One warp of kWarpSize (<= 32) threads computes 1 or 2 batches (kBatchSize).
For the reduction sum, each thread first reduces the values it holds, then the
per-thread partials are combined across the warp with XOR shuffles.
Unread slots are zero-initialized, so they add nothing to the sum; rows past
batch_size are skipped when writing.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
                                    int batch_size, int stride,
                                    int element_count) {
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kLoops = kDimCeil / kWarpSize;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
  constexpr int kVItem = kLoopsV * kVSize;
  using kMode = kps::details::ReduceMode;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  int local_batches = min(batch_size - first_batch, kBatchSize);

  // max index to read (in units of VecT); 0 disables rows past batch_size
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // read data from global memory; unread slots stay zero
  VecT src_reg[kBatchSize][kLoopsV];
  VecT grad_reg[kBatchSize][kLoopsV];
  VecT k_value;
  for (int s = 0; s < kVSize; s++) {
    reinterpret_cast<T*>(&k_value)[s] = 0.0;
  }
  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    int flag = i < local_batches ? 1 : 0;
    // 64-bit row offset: (row * stride) can exceed INT_MAX for large inputs.
    const int64_t ptr = static_cast<int64_t>(first_batch + i) * stride;
    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
  }

  // change T to AccT
  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());

  // compute sum
  AccT sum[kBatchSize]{0.0};
  if (LogMode) {
    // s_i = sum_j grad_{i,j}
    kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
                kMode::kLocalMode>(&sum[0], &grad_tmp[0][0][0],
                                   kps::AddFunctor<AccT>(), true);
  } else {
    // s_i = sum_j src_{i,j} * grad_{i,j}
    AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
    kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
        &sum_tmp[0][0][0], &grad_tmp[0][0][0], &src_tmp[0][0][0],
        kps::MulFunctor<AccT>());
    kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
                kMode::kLocalMode>(&sum[0], &sum_tmp[0][0][0],
                                   kps::AddFunctor<AccT>(), true);
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // write result to global memory
  AccT out[kBatchSize][kLoopsV][kVSize];
  T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;
    AccT* gradptr = &grad_tmp[i][0][0];
    AccT* srcptr = &src_tmp[i][0][0];
    if (LogMode) {
      // out = exp(src) * (-s_i); dst = grad + out = grad - exp(src) * s_i
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpMulFunctor<AccT>>(
          &out[i][0][0], &srcptr[0], ExpMulFunctor<AccT>(-sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::AddFunctor<AccT>>(
          &out_tmp[i][0][0], &gradptr[0], &out[i][0][0],
          kps::AddFunctor<AccT>());
    } else {
      // out = grad - s_i; dst = src * out
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
          &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
          &out_tmp[i][0][0], &srcptr[0], &out[i][0][0],
          kps::MulFunctor<AccT>());
    }
    const int64_t dst_ptr = static_cast<int64_t>(first_batch + i) * stride;
    VecT* dst_v = reinterpret_cast<VecT*>(&dst[dst_ptr]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
}

// Expands to one `case` of the switch in SwitchWarpSoftmaxForward, launching
// the WarpSoftmaxForward instantiation for the given Log2Elements. Expects T,
// VecT, LogMode, blocks, threads, dev_ctx, dst, src, batch_size, stride and
// element_count to be in scope at the expansion site.
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT)          \
  case Log2Elements: {                                         \
    WarpSoftmaxForward<T, VecT, AccT, Log2Elements, LogMode>   \
        <<<blocks, threads, 0, dev_ctx.stream()>>>(            \
            dst, src, batch_size, stride, element_count);      \
    break;                                                     \
  }

/*
  Wrapper of softmax forward: maps the runtime Log2Elements to the matching
  compile-time instantiation of WarpSoftmaxForward. Values in [0, 9] launch a
  kernel; any other value launches nothing.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads,
                              const platform::CUDADeviceContext& dev_ctx,
                              T* dst, const T* src, const int batch_size,
                              const int stride, const int element_count,
                              int Log2Elements) {
  typedef typename details::MPTypeTrait<T>::Type AccT;
  switch (Log2Elements) {
    SOFTMAX_WARP_FORWARD_CASE(0, AccT);
    SOFTMAX_WARP_FORWARD_CASE(1, AccT);
    SOFTMAX_WARP_FORWARD_CASE(2, AccT);
    SOFTMAX_WARP_FORWARD_CASE(3, AccT);
    SOFTMAX_WARP_FORWARD_CASE(4, AccT);
    SOFTMAX_WARP_FORWARD_CASE(5, AccT);
    SOFTMAX_WARP_FORWARD_CASE(6, AccT);
    SOFTMAX_WARP_FORWARD_CASE(7, AccT);
    SOFTMAX_WARP_FORWARD_CASE(8, AccT);
    SOFTMAX_WARP_FORWARD_CASE(9, AccT);
    default:
      // No instantiation for larger sizes; nothing is launched.
      break;
  }
}

// Expands to one `case` of the switch in SwitchWarpSoftmaxBackward, launching
// the WarpSoftmaxBackward instantiation for the given Log2Elements. Expects T,
// VecT, LogMode, blocks, threads, dev_ctx, dst, grad, src, batch_size, stride
// and element_count to be in scope at the expansion site.
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT)         \
  case Log2Elements: {                                         \
    WarpSoftmaxBackward<T, VecT, AccT, Log2Elements, LogMode>  \
        <<<blocks, threads, 0, dev_ctx.stream()>>>(            \
            dst, grad, src, batch_size, stride, element_count); \
    break;                                                     \
  }

/*
  Wrapper of softmax backward: maps the runtime Log2Elements to the matching
  compile-time instantiation of WarpSoftmaxBackward. Values in [0, 9] launch a
  kernel; any other value launches nothing.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
                               const platform::CUDADeviceContext& dev_ctx,
                               T* dst, const T* grad, const T* src,
                               const int batch_size, const int stride,
                               const int element_count, int Log2Elements) {
  typedef typename details::MPTypeTrait<T>::Type AccT;
  switch (Log2Elements) {
    SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(2, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(3, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(4, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(5, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(6, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
    default:
      // No instantiation for larger sizes; nothing is launched.
      break;
  }
}

// The case macros are only needed by the switch wrappers; undefine them so
// they do not leak into files that include this header.
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE

/*
  Computes softmax (LogMode == false) or log-softmax (LogMode == true) of `x`
  along `input_axis` and writes the result into `out`.

  The input is treated as [N, dim, D], where dim is the size of the softmax
  axis. When the axis is innermost (D == 1), dim <= 320 and sizeof(T) <= 4,
  the warp kernel WarpSoftmaxForward is used. Otherwise the work is delegated
  to cuDNN (MIOpen when built with PADDLE_WITH_HIP).
*/
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                    const Tensor& x, const int input_axis,
                                    Tensor* out) {
  auto* out_data = out->data<T>();

  auto dims = x.dims();
  const int rank = dims.size();
  const int axis = CanonicalAxis(input_axis, rank);
  const int dim = dims[axis];
  const int N = SizeToAxis(axis, dims);
  const int D = SizeOutAxis(axis, dims);

  // Empty input: nothing to compute. A launch with zero blocks is an invalid
  // configuration, and cuDNN rejects tensor descriptors with a zero extent.
  if (N == 0 || dim == 0 || D == 0) {
    return;
  }

  constexpr int max_dim = 320;

  if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
    const int kDimLog2 = log2_ceil(dim);
    const int kDimCeil = 1 << kDimLog2;
    int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
    // Must match kBatchSize in WarpSoftmaxForward.
    int batches_per_warp = (kDimCeil <= 32) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / kWarpSize);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(kWarpSize, warps_per_block, 1);

    // vectorization read/write: widest vector that evenly divides a row
    using T4 = typename VecT4<T>::Type;
    using T2 = typename VecT2<T>::Type;
    if (dim % 4 == 0) {
      SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx,
                                               out_data, x.data<T>(), N, dim,
                                               dim, kDimLog2);
    } else if (dim % 2 == 0) {
      SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks, threads, dev_ctx,
                                               out_data, x.data<T>(), N, dim,
                                               dim, kDimLog2);
    } else {
      SwitchWarpSoftmaxForward<T, T, LogMode>(blocks, threads, dev_ctx,
                                              out_data, x.data<T>(), N, dim,
                                              dim, kDimLog2);
    }
  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};
    DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
    miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
    cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif

    auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
    auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                 : MIOPEN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, out_data,
          MIOPEN_SOFTMAX_LOG, mode));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, out_data,
          MIOPEN_SOFTMAX_ACCURATE, mode));
    }
#else
    auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                 : CUDNN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward(
          handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
          desc_, x.data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          out_data));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward(
          handle, CUDNN_SOFTMAX_ACCURATE, mode,
          platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, out_data));
    }
#endif
  }
}

/*
  Computes the gradient of softmax (LogMode == false) or log-softmax
  (LogMode == true) along `input_axis`. `out` is the forward output, `dout`
  its gradient, and the result is written into `dx`.

  The tensors are treated as [N, dim, D], where dim is the size of the softmax
  axis. When the axis is innermost (D == 1), dim <= 320 and sizeof(T) <= 4,
  the warp kernel WarpSoftmaxBackward is used. Otherwise the work is delegated
  to cuDNN (MIOpen when built with PADDLE_WITH_HIP).
*/
template <typename T, bool LogMode = false>
void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                     const Tensor& out, const Tensor& dout,
                                     const int input_axis, Tensor* dx) {
  auto* dx_data = dx->data<T>();

  auto dims = out.dims();
  const int rank = dims.size();
  const int axis = CanonicalAxis(input_axis, rank);
  const int dim = dims[axis];
  const int N = SizeToAxis(axis, dims);
  const int D = SizeOutAxis(axis, dims);

  // Empty input: nothing to compute. A launch with zero blocks is an invalid
  // configuration, and cuDNN rejects tensor descriptors with a zero extent.
  if (N == 0 || dim == 0 || D == 0) {
    return;
  }

  constexpr int max_dim = 320;

  if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
    const int kDimLog2 = log2_ceil(dim);
    const int kDimCeil = 1 << kDimLog2;
    int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
    // Must match kBatchSize in WarpSoftmaxBackward.
    int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / kWarpSize);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(kWarpSize, warps_per_block, 1);

    // vectorization read/write: widest vector that evenly divides a row
    using T4 = typename VecT4<T>::Type;
    using T2 = typename VecT2<T>::Type;
    if (dim % 4 == 0) {
      SwitchWarpSoftmaxBackward<T, T4, LogMode>(
          blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N,
          dim, dim, kDimLog2);
    } else if (dim % 2 == 0) {
      SwitchWarpSoftmaxBackward<T, T2, LogMode>(
          blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N,
          dim, dim, kDimLog2);
    } else {
      SwitchWarpSoftmaxBackward<T, T, LogMode>(
          blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N,
          dim, dim, kDimLog2);
    }
  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};
    DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
    miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
    cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif

    auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
    auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                 : MIOPEN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(),
          desc_, dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          dx_data, MIOPEN_SOFTMAX_LOG, mode));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(),
          desc_, dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          dx_data, MIOPEN_SOFTMAX_ACCURATE, mode));
    }
#else
    auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                 : CUDNN_SOFTMAX_MODE_CHANNEL;
    if (LogMode) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
          handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
          desc_, out.data<T>(), desc_, dout.data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, dx_data));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
          handle, CUDNN_SOFTMAX_ACCURATE, mode,
          platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(), desc_,
          dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_, dx_data));
    }
#endif
  }
}

}  // namespace operators
}  // namespace paddle