softmax_gpudnn.h 50.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include "paddle/phi/backends/gpu/gpu_info.h"
18 19
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
20 21
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
22
#include "paddle/phi/kernels/funcs/aligned_vector.h"
23 24 25 26
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

// See Note [ Why still include the fluid headers? ]
27
#include "paddle/phi/backends/gpu/gpu_device_function.h"
28
#include "paddle/phi/backends/gpu/gpu_dnn.h"
29

30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
// Byte alignment that KeMatrixSoftmaxForward measures row pointers against to
// find how many leading elements precede a vector-load boundary.
#define MATRIX_SOFTMAX_ALIGN_BYTES 16
// NOTE(review): "THREAHOLD" is a misspelling of "THRESHOLD"; it is left as-is
// because code outside this view may reference it. Its consumer is not
// visible here (presumably a dim_size cutoff for kernel selection) — confirm.
#define MATRIX_SOFTMAX_THREAHOLD 100000

// Emits one `case (dim):` of a switch. Inside the case, `kBlockDim` is a
// constexpr copy of `dim`, so the statement(s) passed as __VA_ARGS__ can use
// it as a template argument.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// Same as FIXED_BLOCK_DIM_BASE, but the constant is exposed as `VecSize`.
#define FIXED_VEC_SIZE_BASE(vec_size, ...) \
  case (vec_size): {                       \
    constexpr auto VecSize = (vec_size);   \
    __VA_ARGS__;                           \
  } break

// Case labels for block dims 512, 256, 128, 64 and 32; meant to be placed
// inside a `switch` on the runtime block dimension. Other values match none
// of these cases.
#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

// Case labels for vector sizes 8 and 4, used the same way as FIXED_BLOCK_DIM.
#define FIXED_VEC_SIZE(...)              \
  FIXED_VEC_SIZE_BASE(8, ##__VA_ARGS__); \
  FIXED_VEC_SIZE_BASE(4, ##__VA_ARGS__)

56
namespace phi {
57

58 59
// Shorter in-namespace names for helper types from phi::backends::gpu.
typedef phi::backends::gpu::ScopedTensorDescriptor ScopedTensorDescriptor;
typedef phi::backends::gpu::DataLayout GPUDNNDataLayout;
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74

// VecT4<T>::Type names a built-in vector type intended to span 4 * sizeof(T)
// bytes, so four T values can be moved with one load/store. Only the four
// specializations below define `Type`; the primary template is empty.
// NOTE(review): long4 is 32 bytes only where `long` is 8 bytes; on LLP64
// platforms (Windows) it is 16 bytes — confirm this is acceptable.
template <typename T>
class VecT4 {};

template <>
class VecT4<float> {
 public:
  typedef int4 Type;
};

template <>
class VecT4<double> {
 public:
  typedef long4 Type;
};

template <>
class VecT4<phi::dtype::bfloat16> {
 public:
  typedef int2 Type;
};

template <>
class VecT4<phi::dtype::float16> {
 public:
  typedef int2 Type;
};
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98

// VecT2<T>::Type names a built-in vector type spanning 2 * sizeof(T) bytes,
// so two T values can be moved with one load/store. Only the four
// specializations below define `Type`; the primary template is empty.
template <typename T>
class VecT2 {};

template <>
class VecT2<float> {
 public:
  typedef int2 Type;
};

template <>
class VecT2<double> {
 public:
  typedef int4 Type;
};

template <>
class VecT2<phi::dtype::bfloat16> {
 public:
  typedef int Type;
};

template <>
class VecT2<phi::dtype::float16> {
 public:
  typedef int Type;
};
108

109
// Returns the smallest k >= 0 such that 2^k >= value (0 when value <= 1).
// The power of two is formed with a 64-bit shift: with `1 << k` in int, any
// value above 2^30 drove k to 31 and evaluated `1 << 31` (signed overflow,
// undefined behaviour that in practice never terminated). Now such inputs
// return 31.
static inline int Log2Ceil(int value) {
  int log2_value = 0;
  while ((1LL << log2_value) < value) ++log2_value;
  return log2_value;
}

115 116 117 118 119 120 121 122 123 124 125 126 127 128
// Chooses a 1-D block size for a row of `dim_size` elements processed
// `vec_size` at a time. The result is the smallest power of two that is at
// least min(dim_size / vec_size, 1024) (halved when vec_size > 1), and never
// less than 32.
inline int getBlockSize(int vec_size, uint64_t dim_size) {
  const uint64_t kThreadCap = 1024;
  const uint64_t kMinBlock = 32;

  uint64_t limit = dim_size / vec_size;
  if (limit > kThreadCap) limit = kThreadCap;
  if (vec_size > 1) limit /= 2;

  uint64_t pow2 = 1;
  while (pow2 < limit) pow2 <<= 1;

  return static_cast<int>(pow2 < kMinBlock ? kMinBlock : pow2);
}

129 130 131 132 133 134
// XOR-shuffle all-reduce of BatchSize independent sums across groups of
// WarpSize lanes: after the loop every participating lane holds, in each
// vals[b], the sum of that slot over its WarpSize-lane group. Uses a full
// 0xFFFFFFFF mask, so all 32 lanes of the warp must reach this call.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* vals) {
#pragma unroll
  for (int lane_mask = WarpSize >> 1; lane_mask >= 1; lane_mask >>= 1) {
#pragma unroll
    for (int b = 0; b < BatchSize; ++b) {
      const T partner =
          phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, vals[b], lane_mask);
      vals[b] = vals[b] + partner;
    }
  }
}

// XOR-shuffle all-reduce of BatchSize independent maxima across groups of
// WarpSize lanes: after the loop every participating lane holds, in each
// vals[b], the maximum of that slot over its WarpSize-lane group. Uses a full
// 0xFFFFFFFF mask, so all 32 lanes of the warp must reach this call.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceMax(T* vals) {
#pragma unroll
  for (int lane_mask = WarpSize >> 1; lane_mask >= 1; lane_mask >>= 1) {
#pragma unroll
    for (int b = 0; b < BatchSize; ++b) {
      const T partner =
          phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, vals[b], lane_mask);
      vals[b] = max(vals[b], partner);
    }
  }
}

155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
// Block-wide max of *val; on return every thread's *val holds the maximum
// over all threads of the block.
// Assumes blockDim.x is a multiple of 32 and at most 1024: one partial per
// warp is kept in a 32-entry shared array, and the warp reductions use a full
// lane mask. Must be reached by all threads of the block (__syncthreads).
template <typename T>
__inline__ __device__ void BlockReduceMax(T* val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  WarpReduceMax<T, 1, 32>(val);

  // Wait until every warp has finished reading `shared` in a previous call
  // before overwriting it (BlockReduceSum guards the same way).
  __syncthreads();
  if (lane == 0) shared[wid] = *val;

  __syncthreads();

  int block_span = (blockDim.x + warpSize - 1) >> 5;
  // Lanes with no warp partial contribute -inf, the identity of max. The old
  // -1e10f sentinel produced a wrong result when all inputs were below -1e10.
  const T neg_inf = -std::numeric_limits<T>::infinity();
  *val = (lane < block_span) ? shared[lane] : neg_inf;
  WarpReduceMax<T, 1, 32>(val);
}

// Block-wide sum of *val; on return every thread's *val holds the total over
// all threads of the block. Each warp first reduces in registers, lane 0 of
// each warp publishes its partial to shared memory, and every warp then
// reduces those partials again.
// Assumes blockDim.x is a multiple of 32 and at most 1024 (32-entry shared
// array, full-mask warp shuffles). Must be reached by all threads of the block.
template <typename T>
__inline__ __device__ void BlockReduceSum(T* val) {
  static __shared__ T warp_partials[32];
  const int lane_id = threadIdx.x % 32;
  const int warp_id = threadIdx.x / 32;

  WarpReduceSum<T, 1, 32>(val);

  // Guard against overwriting partials still being read by a previous call.
  __syncthreads();
  if (lane_id == 0) {
    warp_partials[warp_id] = *val;
  }
  __syncthreads();

  const int num_warps = (blockDim.x + warpSize - 1) >> 5;
  const T zero = static_cast<T>(0.0f);
  *val = (lane_id < num_warps) ? warp_partials[lane_id] : zero;
  WarpReduceSum<T, 1, 32>(val);
}

190 191 192 193 194 195 196 197 198
// Maximum reduction functor. initial() (host-only: it carries no __device__
// qualifier) returns -infinity, the identity of max; operator() returns the
// larger operand via the device `max`.
template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor {
  inline Ty initial() {
    const Ty neg_inf = -std::numeric_limits<Ty>::infinity();
    return neg_inf;
  }

  __device__ __forceinline__ Ty operator()(const Ty& lhs, const Ty& rhs) const {
    return max(lhs, rhs);
  }
};

199 200 201 202 203 204 205 206
// Running-max step used with ThreadVecReduce: widens one T element to AccT
// and returns the larger of it and the accumulated value.
template <typename T, typename AccT>
struct MaxFunctor {
  __device__ __forceinline__ AccT operator()(const AccT& max_v,
                                             const T& v) const {
    const AccT candidate = static_cast<AccT>(v);
    return max(max_v, candidate);
  }
};

207
// Elementwise exponential: returns exp(x) converted to Ty.
template <typename Tx, typename Ty = Tx>
struct ExpFunctor {
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto e = std::exp(x);
    return static_cast<Ty>(e);
  }
};

// Elementwise scaled exponential: returns exp(x) * y converted to Ty, where y
// is fixed at construction (1 by default).
template <typename Tx, typename Ty = Tx>
struct ExpMulFunctor {
  HOSTDEVICE inline ExpMulFunctor() : scale_(static_cast<Tx>(1.0f)) {}

  HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : scale_(y) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x) * scale_);
  }

 private:
  Tx scale_;
};

// Elementwise subtraction of a constant: returns x - y converted to Ty, where
// y is fixed at construction (0 by default).
template <typename Tx, typename Ty = Tx>
struct UnarySubFunctor {
  HOSTDEVICE inline UnarySubFunctor() : offset_(static_cast<Tx>(0.0f)) {}

  HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : offset_(y) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x - offset_);
  }

 private:
  Tx offset_;
};

// Elementwise natural logarithm: returns log(x) converted to Ty. The int
// constructor argument is ignored (presumably accepted so this functor can be
// constructed like the parameterized unary functors).
template <typename Tx, typename Ty = Tx>
struct UnaryLogFunctor {
  HOSTDEVICE inline UnaryLogFunctor() {}

  HOSTDEVICE explicit inline UnaryLogFunctor(int /*unused*/) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto logged = std::log(x);
    return static_cast<Ty>(logged);
  }
};

// Elementwise type conversion Tx -> Ty that maps -infinity of Tx explicitly
// to -infinity of Ty; every other value goes through static_cast. The int
// constructor argument is ignored.
template <typename Tx, typename Ty>
struct DataTransFunctor {
  HOSTDEVICE inline DataTransFunctor() {}

  HOSTDEVICE explicit inline DataTransFunctor(int /*unused*/) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    if (x == -std::numeric_limits<Tx>::infinity()) {
      return -std::numeric_limits<Ty>::infinity();
    }
    return static_cast<Ty>(x);
  }
};

// Elementwise division by a constant n, implemented as multiplication by the
// reciprocal 1/n (computed once at construction in double, then cast to Tx).
// Defaults to a reciprocal of 1.
template <typename Tx, typename Ty = Tx>
struct UnaryDivFunctor {
  HOSTDEVICE inline UnaryDivFunctor() : n_inv(static_cast<Tx>(1.0f)) {}

  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n)
      : n_inv(static_cast<Tx>(1.0 / n)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto scaled = x * n_inv;
    return static_cast<Ty>(scaled);
  }

 private:
  Tx n_inv;
};

280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
// Softmax output for one element given the row max and the row sum of
// exp(x - max): exp(x - max) / sum, converted to Ty.
template <typename Tx, typename Ty = Tx>
struct SoftmaxForwardFunctor {
  HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
      : row_max_(max), row_sum_(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto shifted = x - row_max_;
    return static_cast<Ty>(std::exp(shifted) / row_sum_);
  }

 private:
  Tx row_max_;
  Tx row_sum_;
};

// Softmax gradient for one element: out * (grad_out - sum), where `sum` is the
// row's sum of grad_out * out fixed at construction.
template <typename Tx, typename Ty = Tx>
struct SoftmaxBackwardFunctor {
  HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : dot_(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    const auto centered = grad_out - dot_;
    return static_cast<Ty>(out * centered);
  }

 private:
  Tx dot_;
};

// Log-softmax output for one element: x - max - log(sum). log(sum) is computed
// once at construction.
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxForwardFunctor {
  HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
      : row_max_(max), log_sum_(std::log(sum)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto shifted = x - row_max_;
    return static_cast<Ty>(shifted - log_sum_);
  }

 private:
  Tx row_max_;
  Tx log_sum_;
};

// Log-softmax gradient for one element: grad_out - exp(out) * sum, where `sum`
// is the row sum of grad_out fixed at construction.
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxBackwardFunctor {
  HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : grad_sum_(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    const auto correction = std::exp(out) * grad_sum_;
    return static_cast<Ty>(grad_out - correction);
  }

 private:
  Tx grad_sum_;
};

332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
// Running-sum step used with ThreadVecReduce: adds exp(v - max_v) (computed in
// AccT) to the accumulated sum, with max_v fixed at construction.
template <typename T, typename AccT>
struct SumExpFunctor {
  HOSTDEVICE inline SumExpFunctor(AccT v) : max_v(v) {}

  HOSTDEVICE inline AccT operator()(AccT sum, T v) const {
    const AccT shifted = static_cast<AccT>(v) - max_v;
    return sum + std::exp(shifted);
  }

 private:
  AccT max_v;
};

// Per-thread partial reduction of one row of `dim_size` elements of `data`
// with `functor`, starting from `default_value`. Thread t visits scalar
// elements t, t + blockDim.x, ... and, in the vectorized middle part, VecT
// chunks t, t + blockDim.x, ...; the caller combines per-thread results with a
// block reduction afterwards.
//   shift: number of elements `data` sits past an aligned boundary. When > 0,
//     the first blockDim.x-element window starting at data - shift is handled
//     with scalar loads so the remainder can use VecSize-wide loads (assumes
//     blockDim.x * sizeof(T) keeps that remainder aligned for VecT).
// Rows shorter than that first window are now safe: indices past the row end
// are skipped and the function returns before the vector/tail loops (the old
// code read out of bounds and fed a negative size into unsigned `%`).
template <template <typename, typename> class Reduction,
          typename T,
          typename AccT,
          int VecSize>
__device__ __forceinline__ AccT
ThreadVecReduce(T* data,
                int dim_size,
                const int shift,
                const Reduction<T, AccT>& functor,
                AccT default_value) {
  using VecT = phi::AlignedVector<T, VecSize>;
  AccT thread_val = default_value;
  const int block_dim = static_cast<int>(blockDim.x);

  // for memory align, handle the unaligned data in first block.
  int offset = threadIdx.x;
  if (shift > 0) {
    data -= shift;
    dim_size += shift;
    if (offset >= shift && offset < dim_size) {
      thread_val = functor(thread_val, data[offset]);
    }
    dim_size -= block_dim;
    data += block_dim;
  }
  // The whole row fit in the unaligned prefix window.
  if (dim_size <= 0) return thread_val;

  const int last = dim_size % (VecSize * block_dim);

  T v[VecSize];
  VecT* value = reinterpret_cast<VecT*>(&v);

  // Vectorized part: whole VecT chunks up to the last full VecSize*blockDim.x
  // span.
  for (; offset * VecSize < dim_size - last; offset += block_dim) {
    *value = reinterpret_cast<VecT*>(data)[offset];
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      thread_val = functor(thread_val, v[i]);
    }
  }

  // Scalar tail.
  offset = dim_size - last + threadIdx.x;
  for (; offset < dim_size; offset += block_dim) {
    thread_val = functor(thread_val, data[offset]);
  }
  return thread_val;
}

// Writes out[j] = functor(AccT(input[j])) for the j in [0, dim_size) owned by
// this thread, using VecSize-wide loads/stores for the aligned middle part.
//   shift: number of elements both `input` and `out` sit past an aligned
//     boundary (the caller uses this only when the two shifts match). When
//     > 0, the first blockDim.x-element window starting at ptr - shift is
//     written with scalar accesses.
// Rows shorter than that first window are now safe: indices past the row end
// are skipped and the function returns early (the old code wrote out of
// bounds and fed a negative size into unsigned `%`).
template <template <typename, typename> class Reduction,
          typename T,
          typename AccT,
          int VecSize>
__device__ __forceinline__ void ThreadVecWriteVec(T* out,
                                                  T* input,
                                                  int dim_size,
                                                  const int shift,
                                                  Reduction<AccT, T> functor) {
  using VecT = phi::AlignedVector<T, VecSize>;
  const int block_dim = static_cast<int>(blockDim.x);

  // for memory align, handle the unaligned data in first block.
  int offset = threadIdx.x;
  if (shift > 0) {
    input -= shift;
    out -= shift;
    dim_size += shift;
    if (offset >= shift && offset < dim_size) {
      out[offset] = functor(static_cast<AccT>(input[offset]));
    }
    dim_size -= block_dim;
    input += block_dim;
    out += block_dim;
  }
  // The whole row fit in the unaligned prefix window.
  if (dim_size <= 0) return;

  const int last = dim_size % (VecSize * block_dim);

  T in_v[VecSize];
  VecT* in_value = reinterpret_cast<VecT*>(&in_v);

  T out_v[VecSize];
  VecT* out_value = reinterpret_cast<VecT*>(&out_v);

  // Vectorized part.
  for (; offset * VecSize < dim_size - last; offset += block_dim) {
    *in_value = reinterpret_cast<VecT*>(input)[offset];
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      out_v[i] = functor(static_cast<AccT>(in_v[i]));
    }
    reinterpret_cast<VecT*>(out)[offset] = *out_value;
  }

  // the tail
  offset = dim_size - last + threadIdx.x;
  for (; offset < dim_size; offset += block_dim) {
    out[offset] = functor(static_cast<AccT>(input[offset]));
  }
}

// Scalar (non-vector) counterpart of ThreadVecWriteVec, used when input and
// output have different alignment: writes out[j] = functor(AccT(input[j]))
// for j in [0, dim_size). The body is covered in spans of VecSize * blockDim.x
// elements, with each thread handling VecSize elements spaced blockDim.x
// apart per span; the remainder is handled one element per thread per step.
template <template <typename, typename> class Reduction,
          typename T,
          typename AccT,
          int VecSize>
__device__ __forceinline__ void ThreadVecWrite(T* out,
                                               T* input,
                                               int dim_size,
                                               Reduction<AccT, T> functor) {
  const int last = dim_size % (VecSize * blockDim.x);
  const int body_end = dim_size - last;

  int base = threadIdx.x;
  while (base < body_end) {
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      const int pos = base + k * blockDim.x;
      out[pos] = functor(static_cast<AccT>(input[pos]));
    }
    base += blockDim.x * VecSize;
  }

  // Remainder that does not fill a whole span.
  for (int pos = body_end + threadIdx.x; pos < dim_size; pos += blockDim.x) {
    out[pos] = functor(static_cast<AccT>(input[pos]));
  }
}

/*
Block-per-row softmax (log-softmax when LogMode): block blockIdx.x reads the
dim_size elements of row blockIdx.x from src and writes the same row of
softmax.
  1. row max (ThreadVecReduce with MaxFunctor + BlockReduceMax),
  2. row sum of exp(x - max) (ThreadVecReduce with SumExpFunctor +
     BlockReduceSum),
  3. write exp(x - max) / sum, or x - max - log(sum) when LogMode.
Output uses vector stores when input and output rows have the same offset
from a MATRIX_SOFTMAX_ALIGN_BYTES boundary, scalar stores otherwise.
Assumes blockDim.x is a multiple of 32 and at most 1024 (block reductions).
BatchSize and IndexType are not used by the body.
*/
template <typename T,
          typename AccT,
          typename IndexType,
          int BatchSize,
          int VecSize,
          bool LogMode = false>
__global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
  // 64-bit row offset: blockIdx.x * dim_size overflows int once the tensor
  // holds more than INT_MAX elements (e.g. 30000 rows of 100000 elements).
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * dim_size;
  T* batch_input = const_cast<T*>(src) + row_offset;
  T* batch_output = softmax + row_offset;

  // Elements each row pointer lies past a MATRIX_SOFTMAX_ALIGN_BYTES boundary.
  const int input_align_shift =
      ((uint64_t)batch_input) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
  const int output_align_shift =
      ((uint64_t)batch_output) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);

  // get max value. Start from -inf, the identity of max: the previous
  // numeric_limits<AccT>::min() is the smallest *positive* normal value, so
  // rows of all-negative values used a "max" near 0, exp(x - max) could
  // underflow to 0 for every element, and the output became 0/0.
  AccT thread_max = ThreadVecReduce<MaxFunctor, T, AccT, VecSize>(
      batch_input,
      dim_size,
      input_align_shift,
      MaxFunctor<T, AccT>(),
      -std::numeric_limits<AccT>::infinity());
  BlockReduceMax<AccT>(&thread_max);

  // get exp value and sum all
  AccT thread_exp = ThreadVecReduce<SumExpFunctor, T, AccT, VecSize>(
      batch_input,
      dim_size,
      input_align_shift,
      SumExpFunctor<T, AccT>(thread_max),
      static_cast<AccT>(0.));
  BlockReduceSum<AccT>(&thread_exp);

  // write data to softmax_output according to the LogMode
  if (LogMode) {
    LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
    if (input_align_shift == output_align_shift) {
      ThreadVecWriteVec<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, input_align_shift, reduction);
    } else {
      ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, reduction);
    }
  } else {
    SoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
    if (input_align_shift == output_align_shift) {
      ThreadVecWriteVec<SoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, input_align_shift, reduction);
    } else {
      ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, reduction);
    }
  }
}

522 523 524 525 526 527 528 529 530 531
/*
Core function of computing softmax forward for axis=-1.
The computation includes
  - Compute the maximum of each row: maxvalue_{i} = max_{j} src_{i,j}
  - Compute the sum of exps of each row:
    s_{i} = sum_{j} exp(src_{i,j} - maxvalue_{i})
  - Compute: dst_{i,j} = exp(src_{i,j} - maxvalue_{i}) / s_{i}
    (when LogMode is true: dst_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i}))
One warp (up to 32 threads) computes 1 or 2 rows (kBatchSize).
For the max (sum) reduction, each thread first reduces its own elements, then
warp shuffles combine the partial results across the warp.
*/
532 533 534
// Template parameters: VecT is the vector type used for global loads/stores
// (kVSize = sizeof(VecT) / sizeof(T) elements each); rows hold at most
// 2^Log2Elements elements; IndexType is used for all index arithmetic.
// Launch layout: threadIdx.x is the lane within a row group, threadIdx.y
// selects the row group inside the block.
template <typename T,
          typename VecT,
          typename AccT,
          typename IndexType,
          int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax,
                                   const T* src,
                                   const IndexType batch_size,
                                   const IndexType stride,
                                   const IndexType element_count) {
  // Compile-time tiling of a row across the warp.
  constexpr IndexType kDimCeil = 1 << Log2Elements;
  constexpr IndexType kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr IndexType kVSize = sizeof(VecT) / sizeof(T);
  constexpr IndexType kLoops = kDimCeil / kWarpSize;
  constexpr IndexType kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  constexpr IndexType kBatchSize = (kDimCeil <= 32) ? 2 : 1;
  constexpr IndexType kVItem = kLoopsV * kVSize;
  constexpr IndexType kStep = kBatchSize * kVItem;
  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();

  const IndexType first_batch =
      (static_cast<IndexType>(blockDim.y) * blockIdx.x + threadIdx.y) *
      kBatchSize;

  // Per-row limit in VecT units for the bounded read/write; rows at or past
  // batch_size get 0 so they are neither read nor written.
  IndexType vec_limit[kBatchSize];
#pragma unroll
  for (IndexType b = 0; b < kBatchSize; ++b) {
    vec_limit[b] = (first_batch + b < batch_size) ? element_count / kVSize : 0;
  }

  // raw:     values as loaded from global memory (-inf where nothing is read)
  // shifted: raw converted to AccT, later (raw - max); used by log_softmax
  // expd:    exp(shifted); used by softmax
  // result:  values to store, converted back to T
  T raw[kBatchSize][kLoopsV][kVSize];
  AccT shifted[kBatchSize][kLoopsV][kVSize];
  AccT expd[kBatchSize][kLoopsV][kVSize];
  T result[kBatchSize][kLoopsV][kVSize];
  kps::Init<AccT, kStep>(&shifted[0][0][0], kLowInf);
  kps::Init<T, kStep>(&raw[0][0][0], -std::numeric_limits<T>::infinity());

  AccT row_max[kBatchSize];
  kps::Init<AccT, kBatchSize>(&row_max[0], kLowInf);
  AccT row_sum[kBatchSize] = {0};

  // Load each row and widen it to AccT.
#pragma unroll
  for (IndexType b = 0; b < kBatchSize; ++b) {
    const VecT* in_vec =
        reinterpret_cast<const VecT*>(src + (first_batch + b) * stride);
    VecT* raw_vec = reinterpret_cast<VecT*>(&raw[b][0][0]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, true>(
        raw_vec, in_vec, vec_limit[b], 0, kWarpSize, 1);
    kps::ElementwiseUnary<T, AccT, kVItem, 1, DataTransFunctor<T, AccT>>(
        &shifted[b][0][0], &raw[b][0][0], DataTransFunctor<T, AccT>());
  }

  // Row max: thread-local reduction, then across the warp.
  kps::Reduce<AccT,
              kVItem,
              kBatchSize,
              ReduceMaxFunctor<AccT>,
              kps::details::ReduceMode::kLocalMode>(
      &row_max[0], &shifted[0][0][0], ReduceMaxFunctor<AccT>(), true);
  WarpReduceMax<AccT, kBatchSize, kWarpSize>(row_max);

  // Subtract the max, exponentiate, and sum.
#pragma unroll
  for (IndexType b = 0; b < kBatchSize; ++b) {
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, UnarySubFunctor<AccT>>(
        &shifted[b][0][0],
        &shifted[b][0][0],
        UnarySubFunctor<AccT>(row_max[b]));
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, ExpFunctor<AccT>>(
        &expd[b][0][0], &shifted[b][0][0], ExpFunctor<AccT>());
  }
  kps::Reduce<AccT,
              kVItem,
              kBatchSize,
              kps::AddFunctor<AccT>,
              kps::details::ReduceMode::kLocalMode>(
      &row_sum[0], &expd[0][0][0], kps::AddFunctor<AccT>(), true);
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(row_sum);

  // Finalize each row and store it.
#pragma unroll
  for (IndexType b = 0; b < kBatchSize; ++b) {
    if (LogMode) {
      kps::ElementwiseUnary<AccT, T, kVItem, 1, UnarySubFunctor<AccT>>(
          &result[b][0][0],
          &shifted[b][0][0],
          UnarySubFunctor<AccT>(std::log(row_sum[b])));
    } else {
      kps::ElementwiseUnary<AccT, T, kVItem, 1, UnaryDivFunctor<AccT>>(
          &result[b][0][0], &expd[b][0][0], UnaryDivFunctor<AccT>(row_sum[b]));
    }
    VecT* out_vec =
        reinterpret_cast<VecT*>(softmax + (first_batch + b) * stride);
    VecT* result_vec = reinterpret_cast<VecT*>(&result[b][0][0]);
    kps::WriteData<VecT, VecT, kLoopsV, 1, true>(
        out_vec, result_vec, vec_limit[b], 0, kWarpSize, 1);
  }
}

/*
Core function of computing softmax backward for axis=-1 (src is the forward
output).
The computation includes
  - Compute the per-row sum: s_{i} = sum_{j} src_{i,j} * grad_{i,j}
    (when LogMode is true: s_{i} = sum_{j} grad_{i,j})
  - Compute: dst_{i,j} = src_{i,j} * (grad_{i,j} - s_{i})
    (when LogMode is true: dst_{i,j} = grad_{i,j} - exp(src_{i,j}) * s_{i})
One warp (up to 32 threads) computes 1 or 2 rows (kBatchSize).
For the sum reduction, each thread first reduces its own elements, then warp
shuffles combine the partial results across the warp.
*/
651 652 653 654
// Template parameters: VecT is the vector type used for global loads/stores
// (kVSize = sizeof(VecT) / sizeof(T) elements each); rows hold at most
// 2^Log2Elements elements. threadIdx.y selects the row group within the
// block; each row group handles kBatchSize consecutive rows.
// Row starts are computed in 64-bit: (first_batch + i) * stride overflows int
// for tensors with more than INT_MAX elements, even though every index
// argument fits in int.
template <typename T,
          typename VecT,
          typename AccT,
          int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst,
                                    const T* grad,
                                    const T* src,
                                    int batch_size,
                                    int stride,
                                    int element_count) {
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kLoops = kDimCeil / kWarpSize;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
  constexpr int kVItem = kLoopsV * kVSize;
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  int local_batches = min(batch_size - first_batch, kBatchSize);

  // max index to read (in VecT units); 0 for rows at or past batch_size
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // read data from global memory; slots that are not read stay zero
  VecT src_reg[kBatchSize][kLoopsV];
  VecT grad_reg[kBatchSize][kLoopsV];
  VecT k_value;
  for (int s = 0; s < kVSize; s++) {
    reinterpret_cast<T*>(&k_value)[s] = 0.0;
  }
  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    int flag = i < local_batches ? 1 : 0;
    const int64_t ptr = static_cast<int64_t>(first_batch + i) * stride;
    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, true>(
        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
    kps::ReadData<VecT, VecT, kLoopsV, 1, true>(
        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
  }

  // change T to AccT
  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
  kps::ElementwiseUnary<T, AccT, kStep, 1, DataTransFunctor<T, AccT>>(
      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
  kps::ElementwiseUnary<T, AccT, kStep, 1, DataTransFunctor<T, AccT>>(
      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());

  // compute sum: sum(grad) in LogMode, sum(grad * src) otherwise
  AccT sum[kBatchSize]{0.0};
  if (LogMode) {
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &grad_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  } else {
    AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
    kps::ElementwiseBinary<AccT, AccT, kStep, 1, kps::MulFunctor<AccT>>(
        &sum_tmp[0][0][0],
        &grad_tmp[0][0][0],
        &src_tmp[0][0][0],
        kps::MulFunctor<AccT>());
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // write result to global memory
  AccT out[kBatchSize][kLoopsV][kVSize];
  T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;
    if (LogMode) {
      // grad - exp(src) * sum
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, ExpMulFunctor<AccT>>(
          &out[i][0][0], &src_tmp[i][0][0], ExpMulFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, kps::SubFunctor<AccT>>(
          &out_tmp[i][0][0],
          &grad_tmp[i][0][0],
          &out[i][0][0],
          kps::SubFunctor<AccT>());
    } else {
      // src * (grad - sum)
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, UnarySubFunctor<AccT>>(
          &out[i][0][0], &grad_tmp[i][0][0], UnarySubFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, kps::MulFunctor<AccT>>(
          &out_tmp[i][0][0],
          &src_tmp[i][0][0],
          &out[i][0][0],
          kps::MulFunctor<AccT>());
    }
    const int64_t ptr = static_cast<int64_t>(first_batch + i) * stride;
    VecT* dst_v = reinterpret_cast<VecT*>(&dst[ptr]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    kps::WriteData<VecT, VecT, kLoopsV, 1, true>(
        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
}

769 770 771 772 773
// One `case` of the switch in SwitchWarpSoftmaxForward: launches
// WarpSoftmaxForward specialized for `Log2Elements` on dev_ctx's stream.
// Expects blocks, threads, dev_ctx, dst, src, batch_size, stride,
// element_count, T, VecT, IndexType and LogMode to be in scope at the use site.
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT)                   \
  case Log2Elements:                                                    \
    WarpSoftmaxForward<T, VecT, AccT, IndexType, Log2Elements, LogMode> \
        <<<blocks, threads, 0, dev_ctx.stream()>>>(                     \
            dst, src, batch_size, stride, element_count);               \
    break;

/*
  Wrapper of softmax forward with template instantiation on size of input:
  maps the runtime Log2Elements (0..10) onto the matching compile-time
  WarpSoftmaxForward instantiation, accumulating in MPTypeTrait<T>::Type.
*/
template <typename T, typename VecT, typename IndexType, bool LogMode>
void SwitchWarpSoftmaxForward(const IndexType blocks,
                              const dim3 threads,
                              const GPUContext& dev_ctx,
                              T* dst,
                              const T* src,
                              const IndexType batch_size,
                              const IndexType stride,
                              const IndexType element_count,
                              IndexType Log2Elements) {
  typedef typename phi::dtype::MPTypeTrait<T>::Type AccT;
  switch (Log2Elements) {
    SOFTMAX_WARP_FORWARD_CASE(10, AccT);
    SOFTMAX_WARP_FORWARD_CASE(9, AccT);
    SOFTMAX_WARP_FORWARD_CASE(8, AccT);
    SOFTMAX_WARP_FORWARD_CASE(7, AccT);
    SOFTMAX_WARP_FORWARD_CASE(6, AccT);
    SOFTMAX_WARP_FORWARD_CASE(5, AccT);
    SOFTMAX_WARP_FORWARD_CASE(4, AccT);
    SOFTMAX_WARP_FORWARD_CASE(3, AccT);
    SOFTMAX_WARP_FORWARD_CASE(2, AccT);
    SOFTMAX_WARP_FORWARD_CASE(1, AccT);
    SOFTMAX_WARP_FORWARD_CASE(0, AccT);
    default:
      // NOTE(review): nothing is launched and no error is raised for values
      // outside 0..10; callers presumably guarantee the range.
      break;
  }
}

807 808 809 810 811
// One `case` of the switch in SwitchWarpSoftmaxBackward: launches
// WarpSoftmaxBackward specialized for `Log2Elements` on dev_ctx's stream.
// Expects blocks, threads, dev_ctx, dst, grad, src, batch_size, stride,
// element_count, T, VecT and LogMode to be in scope at the use site.
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT)          \
  case Log2Elements:                                            \
    WarpSoftmaxBackward<T, VecT, AccT, Log2Elements, LogMode>   \
        <<<blocks, threads, 0, dev_ctx.stream()>>>(             \
            dst, grad, src, batch_size, stride, element_count); \
    break;

/*
Wrapper of softmax backward with template instantiation on size of input:
maps the runtime Log2Elements (0..10) onto the matching compile-time
WarpSoftmaxBackward instantiation, accumulating in MPTypeTrait<T>::Type.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks,
                               const dim3 threads,
                               const GPUContext& dev_ctx,
                               T* dst,
                               const T* grad,
                               const T* src,
                               const int batch_size,
                               const int stride,
                               const int element_count,
                               int Log2Elements) {
  typedef typename phi::dtype::MPTypeTrait<T>::Type AccT;
  switch (Log2Elements) {
    SOFTMAX_WARP_BACKWARD_CASE(10, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(6, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(5, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(4, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(3, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(2, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
    default:
      // NOTE(review): nothing is launched and no error is raised for values
      // outside 0..10; callers presumably guarantee the range.
      break;
  }
}

// The case macros are only meant for the two dispatchers above.
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE

849 850 851 852 853
/**
 * <NormalSoftmaxKernel>
 * Better performance when axis != -1
 */

854 855 856 857
// Chooses the grid for a (high_dim, mid_dim, low_dim) launch with the given
// block.
//   - The block budget is the number of threads the current device can keep
//     resident (multiprocessors * max threads per multiprocessor) divided by
//     threads per block.
//   - grid.x covers low_dim with block.x threads per block, capped by the
//     budget.
//   - grid.y spreads the remaining budget over high_dim.
// mid_dim is not used here.
// Both grid dimensions are clamped to at least 1. With low_dim == 0 the old
// code computed grid_x == 0 and then divided by it, and with high_dim == 0 it
// produced an invalid 0-sized grid. A kernel that bounds-checks its indices
// simply does no work for such sizes.
static void GetGridDim(
    int high_dim, int mid_dim, int low_dim, const dim3& block, dim3* grid) {
  int device_id = phi::backends::gpu::GetCurrentDeviceId();
  int max_mp = phi::backends::gpu::GetGPUMultiProcessors(device_id);
  int max_threads_per_mp =
      phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
  int max_threads = max_threads_per_mp * max_mp;
  int num_threads = block.x * block.y;
  int max_num_blocks = std::max(max_threads / num_threads, 1);

  int grid_x = (low_dim + block.x - 1) / block.x;
  grid_x = std::max(std::min(grid_x, max_num_blocks), 1);
  int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
  grid_y = std::max(std::min(grid_y, high_dim), 1);
  grid->x = grid_x;
  grid->y = grid_y;
}

/**
 * Chooses the block shape for the "normal" (non-innermost axis) softmax
 * kernels.
 *
 * x threads cover low_dim and y threads cover mid_dim. Both extents are
 * rounded up to a power of two. x is first capped at 32 so that y gets a
 * share of the 1024-thread budget. x is then allowed to grow back into
 * whatever budget y left unused. block->z is not touched.
 */
static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
  constexpr int kThreadBudget = 1024;

  const int low_pow2 = 1 << Log2Ceil(low_dim);
  const int mid_pow2 = 1 << Log2Ceil(mid_dim);

  const int first_pass_x = std::min(low_pow2, 32);
  const int threads_y = std::min(mid_pow2, kThreadBudget / first_pass_x);
  const int threads_x = std::min(low_pow2, kThreadBudget / threads_y);

  block->x = threads_x;
  block->y = threads_y;
}

881 882
/**
 * Computes block and grid shapes for NormalSoftmaxForward/Backward on a
 * tensor viewed as [high_dim, mid_dim, low_dim].
 *
 * The block is chosen first because the grid is sized from it.
 */
static void GetLaunchConfig(
    int high_dim, int mid_dim, int low_dim, dim3* grid, dim3* block) {
  GetBlockDim(mid_dim, low_dim, block);
  const dim3 chosen_block = *block;
  GetGridDim(high_dim, mid_dim, low_dim, chosen_block, grid);
}

887 888
/**
 * Softmax / log-softmax over the middle axis of a tensor viewed as
 * [high_dim, mid_dim, low_dim] (row-major, low_dim innermost).
 *
 * Launch layout (as produced by GetLaunchConfig):
 *   - threadIdx.x + blockIdx.x * blockDim.x walks low_dim (gridDim.x stride),
 *   - blockIdx.y walks high_dim (gridDim.y stride),
 *   - threadIdx.y splits mid_dim into blockDim.y partial max/sum values that
 *     kps::Reduce (kGlobalMode) combines when blockDim.y > 1.
 * For each (high_id, low_id) column the kernel computes the column max, the
 * sum of exp(x - max), and then writes Functor<AccT, T>(max, sum)(x) for
 * every element.
 *
 * Offsets are 64-bit. The element count high_dim * mid_dim * low_dim can
 * exceed INT32_MAX even though each extent fits in an int, for example on
 * the driver's int64 path.
 *
 * NOTE(review): kps::Reduce presumably synchronizes the block internally.
 * Threads whose low_id loop finishes early (low_dim not a multiple of
 * blockDim.x) do not reach it. Confirm this is safe.
 */
template <typename T,
          typename AccT,
          template <typename, typename>
          class Functor>
__global__ void NormalSoftmaxForward(
    T* output, const T* input, int high_dim, int mid_dim, int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int64_t high_stride = static_cast<int64_t>(mid_dim) * low_dim;
  const int64_t mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      // 64-bit product: avoids int32 overflow on large tensors.
      const int64_t input_offset = high_id * high_stride + low_id;

      // 1. reduce max
      AccT max_value = -std::numeric_limits<AccT>::infinity();
      AccT value = -std::numeric_limits<AccT>::infinity();
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        max_value = kps::MaxFunctor<AccT>()(max_value, value);
      }

      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, kps::MaxFunctor<AccT>, kMode::kGlobalMode>(
            &max_value, &max_value, kps::MaxFunctor<AccT>(), false);
      }

      // 2. reduce sum of exp(x - max); subtracting the max keeps exp bounded
      AccT sum = 0;
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        sum += std::exp(value - max_value);
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 3. (log)softmax
      Functor<AccT, T> functor(max_value, sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const int64_t data_offset = input_offset + mid_id * mid_stride;
        output[data_offset] = functor(static_cast<AccT>(input[data_offset]));
      }
    }
  }
}

935 936
/**
 * Gradient of softmax / log-softmax over the middle axis of a tensor viewed
 * as [high_dim, mid_dim, low_dim] (row-major, low_dim innermost).
 *
 * Uses the same launch layout as NormalSoftmaxForward (see GetLaunchConfig).
 * For each (high_id, low_id) column it reduces:
 *   - LogMode:  sum(dy)
 *   - !LogMode: sum(dy * y)
 * It then writes Functor<AccT, T>(sum)(dy, y) into input_grad for every
 * element of the column.
 *
 * Offsets are 64-bit. The element count high_dim * mid_dim * low_dim can
 * exceed INT32_MAX even though each extent fits in an int.
 *
 * NOTE(review): kps::Reduce presumably synchronizes the block internally.
 * Threads whose low_id loop finishes early do not reach it. Confirm this is
 * safe.
 */
template <typename T,
          typename AccT,
          template <typename, typename>
          class Functor,
          bool LogMode>
__global__ void NormalSoftmaxBackward(T* input_grad,
                                      const T* output_grad,
                                      const T* output,
                                      int high_dim,
                                      int mid_dim,
                                      int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int64_t high_stride = static_cast<int64_t>(mid_dim) * low_dim;
  const int64_t mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      // 64-bit product: avoids int32 overflow on large tensors.
      const int64_t grad_offset = high_id * high_stride + low_id;

      // 1. reduce sum
      AccT sum = 0;
      if (LogMode) {
        for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
          const int64_t data_offset = grad_offset + mid_id * mid_stride;
          sum += static_cast<AccT>(output_grad[data_offset]);
        }
      } else {
        for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
          const int64_t data_offset = grad_offset + mid_id * mid_stride;
          sum += static_cast<AccT>(output_grad[data_offset]) *
                 static_cast<AccT>(output[data_offset]);
        }
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 2. (log)softmax backward
      Functor<AccT, T> functor(sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const int64_t data_offset = grad_offset + mid_id * mid_stride;
        input_grad[data_offset] =
            functor(static_cast<AccT>(output_grad[data_offset]),
                    static_cast<AccT>(output[data_offset]));
      }
    }
  }
}

985
/**
 * Launches NormalSoftmaxForward on dev_ctx's stream for a tensor viewed as
 * [high_dim, mid_dim, low_dim], normalizing over mid_dim.
 *
 * LogMode selects LogSoftmaxForwardFunctor; otherwise SoftmaxForwardFunctor
 * is used. Computation is done in the mixed-precision type
 * MPTypeTrait<T>::Type.
 */
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const GPUContext& dev_ctx,
                                T* output_data,
                                const T* input_data,
                                int high_dim,
                                int mid_dim,
                                int low_dim) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;

  dim3 grid;
  dim3 block;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
  auto stream = dev_ctx.stream();

  if (!LogMode) {
    NormalSoftmaxForward<T, AccT, SoftmaxForwardFunctor>
        <<<grid, block, 0, stream>>>(
            output_data, input_data, high_dim, mid_dim, low_dim);
    return;
  }
  NormalSoftmaxForward<T, AccT, LogSoftmaxForwardFunctor>
      <<<grid, block, 0, stream>>>(
          output_data, input_data, high_dim, mid_dim, low_dim);
}

1006
/**
 * Launches NormalSoftmaxBackward on dev_ctx's stream for a tensor viewed as
 * [high_dim, mid_dim, low_dim], where the forward softmax ran over mid_dim.
 *
 * Writes input_grad_data from output_grad_data (dy) and output_data (y).
 * LogMode selects LogSoftmaxBackwardFunctor; otherwise
 * SoftmaxBackwardFunctor is used.
 */
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
                                 T* input_grad_data,
                                 const T* output_grad_data,
                                 const T* output_data,
                                 int high_dim,
                                 int mid_dim,
                                 int low_dim) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;

  dim3 grid;
  dim3 block;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
  auto stream = dev_ctx.stream();

  if (!LogMode) {
    NormalSoftmaxBackward<T, AccT, SoftmaxBackwardFunctor, LogMode>
        <<<grid, block, 0, stream>>>(input_grad_data,
                                     output_grad_data,
                                     output_data,
                                     high_dim,
                                     mid_dim,
                                     low_dim);
    return;
  }
  NormalSoftmaxBackward<T, AccT, LogSoftmaxBackwardFunctor, LogMode>
      <<<grid, block, 0, stream>>>(input_grad_data,
                                   output_grad_data,
                                   output_data,
                                   high_dim,
                                   mid_dim,
                                   low_dim);
}

1036 1037 1038 1039 1040 1041
/**
 * Collapses `dims` around `axis` into a 4-entry shape
 * {outer, dims[axis], inner, 1}.
 *
 * outer / inner come from phi::funcs::SizeToAxis / SizeOutAxis. These are
 * presumably the products of the dims before / after `axis`; confirm in
 * axis_utils.h. The trailing 1 pads the shape to four entries. The cuDNN
 * helpers in this file pass it to an NCHW descriptor.
 */
template <typename T = int>
static std::vector<T> GetSoftmaxTensorDims(const phi::DDim& dims,
                                           const int axis) {
  const T axis_extent = static_cast<T>(dims[axis]);
  const T outer = phi::funcs::SizeToAxis<T>(axis, dims);
  const T inner = phi::funcs::SizeOutAxis<T>(axis, dims);
  return std::vector<T>{outer, axis_extent, inner, static_cast<T>(1)};
}

/**
 * Runs one cuDNN (CUDA) / MIOpen (HIP) softmax-forward call.
 *
 * The same NCHW descriptor built from `tensor_dims` describes both input and
 * output. INSTANCE mode is used when `axis` is the last axis of a
 * rank-`rank` tensor, CHANNEL mode otherwise. `log_mode` selects the LOG
 * algorithm instead of ACCURATE. Scaling is alpha = 1 and beta = 0, so
 * out_data is overwritten. Library errors are raised through
 * PADDLE_ENFORCE_GPU_SUCCESS.
 */
template <typename T>
void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
                               const T* x_data,
                               const int axis,
                               const int rank,
                               const bool log_mode,
                               const std::vector<int>& tensor_dims,
                               T* out_data) {
  const bool over_last_axis = (axis == rank - 1);
  auto handle = dev_ctx.cudnn_handle();
  const GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
  ScopedTensorDescriptor scoped_desc;
  auto alpha = phi::backends::gpu::CudnnDataType<T>::kOne();
  auto beta = phi::backends::gpu::CudnnDataType<T>::kZero();

#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t io_desc =
      scoped_desc.descriptor<T>(layout, tensor_dims);
  const auto softmax_mode = over_last_axis ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                           : MIOPEN_SOFTMAX_MODE_CHANNEL;
  const auto softmax_algo =
      log_mode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::miopenSoftmaxForward_V2(handle,
                                            alpha,
                                            io_desc,
                                            x_data,
                                            beta,
                                            io_desc,
                                            out_data,
                                            softmax_algo,
                                            softmax_mode));
#else
  cudnnTensorDescriptor_t io_desc =
      scoped_desc.descriptor<T>(layout, tensor_dims);
  const auto softmax_mode = over_last_axis ? CUDNN_SOFTMAX_MODE_INSTANCE
                                           : CUDNN_SOFTMAX_MODE_CHANNEL;
  const auto softmax_algo =
      log_mode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSoftmaxForward(handle,
                                                               softmax_algo,
                                                               softmax_mode,
                                                               alpha,
                                                               io_desc,
                                                               x_data,
                                                               beta,
                                                               io_desc,
                                                               out_data));
#endif
}

1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115
/**
 * Softmax forward through cuDNN/MIOpen, split into several library calls so
 * that no single call addresses more than INT32_MAX elements. The
 * descriptor dims are ints.
 *
 * The tensor is viewed as [rows, dim, inner, 1] (GetSoftmaxTensorDims). Each
 * call processes a contiguous run of whole rows; one row holds
 * dim * inner elements.
 *
 * Empty inputs (no rows or an empty row) return without calling the library.
 * Previously an empty softmax axis divided by zero.
 */
template <typename T>
void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
                                     const DenseTensor& x,
                                     const int axis,
                                     const bool log_mode,
                                     DenseTensor* out) {
  auto* out_data = out->data<T>();
  auto* x_data = x.data<T>();
  const int rank = x.dims().size();

  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
  const int64_t num_rows = tensor_dims[0];
  // Elements per row. Equals the softmax extent when the inner extent is 1
  // (the only case current callers use); kept general so chunk offsets stay
  // correct otherwise.
  const int64_t row_numel =
      static_cast<int64_t>(tensor_dims[1]) * tensor_dims[2];
  if (num_rows <= 0 || row_numel <= 0) {
    return;
  }

  // At least one row per call so the loop always makes progress.
  const int64_t rows_per_call = std::max<int64_t>(
      std::numeric_limits<int32_t>::max() / row_numel, 1);
  for (int64_t row = 0; row < num_rows; row += rows_per_call) {
    tensor_dims[0] =
        static_cast<int>(std::min<int64_t>(num_rows - row, rows_per_call));
    // Offsets are computed per chunk, so pointers are never advanced past the
    // end of the buffers.
    const int64_t offset = row * row_numel;
    SoftmaxForwardCudnnKernel<T>(dev_ctx,
                                 x_data + offset,
                                 axis,
                                 rank,
                                 log_mode,
                                 tensor_dims,
                                 out_data + offset);
  }
}

1116 1117
/**
 * Runs one cuDNN (CUDA) / MIOpen (HIP) softmax-backward call.
 *
 * One NCHW descriptor built from `tensor_dims` describes out (y), dout (dy)
 * and dx. INSTANCE mode is used when `axis` is the last axis of a
 * rank-`rank` tensor, CHANNEL mode otherwise. `log_mode` selects the LOG
 * algorithm instead of ACCURATE. Scaling is alpha = 1 and beta = 0, so
 * dx_data is overwritten. Library errors are raised through
 * PADDLE_ENFORCE_GPU_SUCCESS.
 */
template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                const T* out_data,
                                const T* dout_data,
                                const int axis,
                                const int rank,
                                const bool log_mode,
                                const std::vector<int>& tensor_dims,
                                T* dx_data) {
  const bool over_last_axis = (axis == rank - 1);
  auto handle = dev_ctx.cudnn_handle();
  const GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
  ScopedTensorDescriptor scoped_desc;
  auto alpha = phi::backends::gpu::CudnnDataType<T>::kOne();
  auto beta = phi::backends::gpu::CudnnDataType<T>::kZero();

#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t io_desc =
      scoped_desc.descriptor<T>(layout, tensor_dims);
  const auto softmax_mode = over_last_axis ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                           : MIOPEN_SOFTMAX_MODE_CHANNEL;
  const auto softmax_algo =
      log_mode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::miopenSoftmaxBackward_V2(handle,
                                             alpha,
                                             io_desc,
                                             out_data,
                                             io_desc,
                                             dout_data,
                                             beta,
                                             io_desc,
                                             dx_data,
                                             softmax_algo,
                                             softmax_mode));
#else
  cudnnTensorDescriptor_t io_desc =
      scoped_desc.descriptor<T>(layout, tensor_dims);
  const auto softmax_mode = over_last_axis ? CUDNN_SOFTMAX_MODE_INSTANCE
                                           : CUDNN_SOFTMAX_MODE_CHANNEL;
  const auto softmax_algo =
      log_mode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::cudnnSoftmaxBackward(handle,
                                         softmax_algo,
                                         softmax_mode,
                                         alpha,
                                         io_desc,
                                         out_data,
                                         io_desc,
                                         dout_data,
                                         beta,
                                         io_desc,
                                         dx_data));
#endif
}

1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200
/**
 * Softmax backward through cuDNN/MIOpen, split into several library calls so
 * that no single call addresses more than INT32_MAX elements. The
 * descriptor dims are ints.
 *
 * The tensor is viewed as [rows, dim, inner, 1] (GetSoftmaxTensorDims). Each
 * call processes a contiguous run of whole rows of out, dout and dx; one row
 * holds dim * inner elements.
 *
 * Empty inputs (no rows or an empty row) return without calling the library.
 * Previously an empty softmax axis divided by zero.
 */
template <typename T>
void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                      const DenseTensor& out,
                                      const DenseTensor& dout,
                                      const int axis,
                                      const bool log_mode,
                                      DenseTensor* dx) {
  auto* dx_data = dx->data<T>();
  auto* out_data = out.data<T>();
  auto* dout_data = dout.data<T>();
  const int rank = out.dims().size();

  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
  const int64_t num_rows = tensor_dims[0];
  // Elements per row. Equals the softmax extent when the inner extent is 1
  // (the only case current callers use); kept general so chunk offsets stay
  // correct otherwise.
  const int64_t row_numel =
      static_cast<int64_t>(tensor_dims[1]) * tensor_dims[2];
  if (num_rows <= 0 || row_numel <= 0) {
    return;
  }

  // At least one row per call so the loop always makes progress.
  const int64_t rows_per_call = std::max<int64_t>(
      std::numeric_limits<int32_t>::max() / row_numel, 1);
  for (int64_t row = 0; row < num_rows; row += rows_per_call) {
    tensor_dims[0] =
        static_cast<int>(std::min<int64_t>(num_rows - row, rows_per_call));
    // Offsets are computed per chunk, so pointers are never advanced past the
    // end of the buffers.
    const int64_t offset = row * row_numel;
    SoftmaxBackwardCudnnKernel<T>(dev_ctx,
                                  out_data + offset,
                                  dout_data + offset,
                                  axis,
                                  rank,
                                  log_mode,
                                  tensor_dims,
                                  dx_data + offset);
  }
}

1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224
/**
 * Launches KeMatrixSoftmaxForward with one thread block per row: grid = N,
 * each row holding dim_size elements. The block size is chosen by
 * getBlockSize(vec_size, dim_size). Elements are accessed in vectors of
 * MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T) elements.
 *
 * Block sizes and vector widths are dispatched to compile-time constants via
 * the FIXED_BLOCK_DIM / FIXED_VEC_SIZE case macros. A vector width not
 * covered by FIXED_VEC_SIZE now raises an error. It used to silently launch
 * nothing and leave `out` unwritten. NOTE(review): confirm which widths
 * FIXED_VEC_SIZE covers. With 8-byte T the width is 2.
 */
template <typename T, typename IndexType, bool LogMode>
void LaunchKeMatrixSoftmaxForwardKernel(
    const GPUContext& dev_ctx, T* out, const T* input, int N, int dim_size) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  const int vec_size = MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
  switch (getBlockSize(vec_size, dim_size)) {
    FIXED_BLOCK_DIM(switch (vec_size) {
      FIXED_VEC_SIZE(
          KeMatrixSoftmaxForward<T,
                                 AccT,
                                 IndexType,
                                 kBlockDim,
                                 VecSize,
                                 LogMode>
          <<<N, kBlockDim, 0, dev_ctx.stream()>>>(out, input, dim_size));
      default:
        PADDLE_THROW(errors::Unimplemented(
            "Unsupported vector size %d in the softmax cuda kernel.",
            vec_size));
    });
    default:
      PADDLE_THROW(
          errors::Fatal("the input dim has error in the softmax cuda kernel."));
  }
}

1225 1226
// With cuDNN older than 8.1.0 the bf16 softmax is not offered by the library,
// so the bf16 specializations of the cuDNN launchers just raise Unavailable.
// NOTE(review): when CUDNN_VERSION is undefined (e.g. HIP builds), `#if`
// evaluates it as 0, so these specializations are presumably active there
// too. Confirm this is intended.
#if CUDNN_VERSION < 8100
template <>
inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
    const GPUContext& dev_ctx,
    const DenseTensor& x,
    const int axis,
    const bool log_mode,
    DenseTensor* out) {
  PADDLE_THROW(errors::Unavailable(
      "This kernel is not supported when the dtype is bf16 and "
      "CUDNN_VERSION < 8100."));
}

template <>
inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
    const GPUContext& dev_ctx,
    const DenseTensor& out,
    const DenseTensor& dout,
    const int axis,
    const bool log_mode,
    DenseTensor* dx) {
  PADDLE_THROW(errors::Unavailable(
      "This kernel is not supported when the dtype is bf16 and "
      "CUDNN_VERSION < 8100."));
}
#endif

1251
/**
 * Decides whether the cuDNN/MIOpen softmax should be used instead of the
 * hand-written kernels in this file.
 *
 * Returns true only when all of the following hold:
 *   - the context has a DNN library handle,
 *   - the softmax axis is the innermost one (`last_dim`),
 *   - it is not the case that softmax_dim <= 1024 with sizeof(T) <= 4,
 *   - softmax_dim < MATRIX_SOFTMAX_THREAHOLD.
 *
 * NOTE(review): bf16 with CUDNN_VERSION < 8100 is not filtered out here, so
 * such builds reach the bf16 launcher specializations that throw
 * Unavailable. Confirm whether that case should route to a non-cuDNN kernel
 * instead.
 */
template <typename T>
bool UseCudnnSoftmax(const GPUContext& ctx,
                     int64_t softmax_dim,
                     bool last_dim) {
  const bool has_dnn_handle = static_cast<bool>(ctx.cudnn_handle());
  if (!has_dnn_handle || !last_dim) {
    return false;
  }

  constexpr int max_dim = 1024;
  const bool small_dim_narrow_type = softmax_dim <= max_dim && sizeof(T) <= 4;
  const bool huge_dim = softmax_dim >= MATRIX_SOFTMAX_THREAHOLD;
  return !small_dim_narrow_type && !huge_dim;
}

1273 1274 1275 1276 1277
/**
 * Softmax / log-softmax forward over `input_axis` of `x`, written to `out`,
 * with IndexType as the index type used for the collapsed extents.
 *
 * The tensor is viewed as [N, dim, D] around the canonicalized axis and
 * dispatched as follows:
 *   - D != 1 (axis not innermost): NormalSoftmaxForward kernels.
 *   - UseCudnnSoftmax(...) true: cuDNN/MIOpen.
 *   - dim >= MATRIX_SOFTMAX_THREAHOLD: KeMatrixSoftmaxForward (one block per
 *     row).
 *   - otherwise: warp-based kernels via SwitchWarpSoftmaxForward. Rows are
 *     padded to the next power of two; float uses scalar access, other types
 *     use 4- or 2-wide vectors when dim allows.
 */
template <typename T, typename IndexType, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
                                        const DenseTensor& x,
                                        const int input_axis,
                                        DenseTensor* out) {
  auto* out_data = out->data<T>();

  const int rank = x.dims().size();
  const int axis = phi::funcs::CanonicalAxis(input_axis, rank);
  std::vector<IndexType> shape =
      GetSoftmaxTensorDims<IndexType>(x.dims(), axis);
  IndexType N = shape[0];
  IndexType dim = shape[1];
  int D = shape[2];

  // Softmax axis is not innermost: generic strided kernels.
  if (D != 1) {
    LaunchNormalSoftmaxForward<T, LogMode>(
        dev_ctx, out_data, x.data<T>(), N, dim, D);
    return;
  }

  // Library path.
  if (UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
    LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
    return;
  }

  const T* x_data = x.data<T>();

  // Very long rows: block-per-row kernel.
  if (dim >= MATRIX_SOFTMAX_THREAHOLD) {
    LaunchKeMatrixSoftmaxForwardKernel<T, IndexType, LogMode>(
        dev_ctx, out_data, x_data, N, dim);
    return;
  }

  // Warp-based kernels: each row is padded to a power of two; rows shorter
  // than a warp use a narrower "warp" and two rows per warp.
  int dim_log2 = static_cast<int>(Log2Ceil(dim));
  const IndexType dim_ceil = 1 << dim_log2;
  const int warp_size = static_cast<int>(std::min<IndexType>(dim_ceil, 32));
  const int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;

  // 128 threads per block to maximize GPU utilization.
  constexpr int threads_per_block = 128;
  const int warps_per_block = threads_per_block / warp_size;
  const int batches_per_block = warps_per_block * batches_per_warp;
  IndexType blocks = (N + batches_per_block - 1) / batches_per_block;
  dim3 threads(warp_size, warps_per_block, 1);

  // Vectorized read/write types.
  using T4 = typename VecT4<T>::Type;
  using T2 = typename VecT2<T>::Type;

  if (std::is_same<T, float>::value || dim % 2 != 0) {
    // float always uses scalar access; odd rows cannot be vectorized.
    SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
                                                       threads,
                                                       dev_ctx,
                                                       out_data,
                                                       x_data,
                                                       N,
                                                       dim,
                                                       dim,
                                                       dim_log2);
  } else if (dim % 4 == 0) {
    SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
                                                        threads,
                                                        dev_ctx,
                                                        out_data,
                                                        x_data,
                                                        N,
                                                        dim,
                                                        dim,
                                                        dim_log2);
  } else {
    SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
                                                        threads,
                                                        dev_ctx,
                                                        out_data,
                                                        x_data,
                                                        N,
                                                        dim,
                                                        dim,
                                                        dim_log2);
  }
}

1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377
/**
 * Softmax / log-softmax forward entry point. Selects 64-bit indexing when
 * x has at least INT32_MAX elements, and 32-bit indexing otherwise.
 */
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                    const DenseTensor& x,
                                    const int input_axis,
                                    DenseTensor* out) {
  const bool needs_int64_index =
      x.numel() >= std::numeric_limits<int32_t>::max();
  if (!needs_int64_index) {
    SoftmaxForwardCUDAKernelDriverImpl<T, int32_t, LogMode>(
        dev_ctx, x, input_axis, out);
    return;
  }
  SoftmaxForwardCUDAKernelDriverImpl<T, int64_t, LogMode>(
      dev_ctx, x, input_axis, out);
}

1378
/**
 * Softmax / log-softmax backward over `input_axis`: computes dx from the
 * forward output `out` (y) and its gradient `dout` (dy).
 *
 * The tensor is viewed as [N, dim, D] (32-bit extents) around the
 * canonicalized axis and dispatched as follows:
 *   - D != 1 (axis not innermost): NormalSoftmaxBackward kernels.
 *   - UseCudnnSoftmax(...) true: cuDNN/MIOpen.
 *   - otherwise: warp-based kernels via SwitchWarpSoftmaxBackward, using 4-
 *     or 2-wide vector access when dim allows.
 */
template <typename T, bool LogMode = false>
void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                     const DenseTensor& out,
                                     const DenseTensor& dout,
                                     const int input_axis,
                                     DenseTensor* dx) {
  auto* dx_data = dx->data<T>();

  const int rank = out.dims().size();
  const int axis = phi::funcs::CanonicalAxis(input_axis, rank);
  std::vector<int> shape = GetSoftmaxTensorDims(out.dims(), axis);
  int N = shape[0];
  int dim = shape[1];
  int D = shape[2];

  // Softmax axis is not innermost: generic strided kernels.
  if (D != 1) {
    LaunchNormalSoftmaxBackward<T, LogMode>(
        dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
    return;
  }

  // Library path.
  if (UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
    LaunchSoftmaxBackwardCudnnKernel<T>(
        dev_ctx, out, dout, axis, LogMode, dx);
    return;
  }

  // Warp-based kernels: each row is padded to a power of two; rows shorter
  // than a warp use a narrower "warp".
  int dim_log2 = Log2Ceil(dim);
  const int dim_ceil = 1 << dim_log2;
  const int warp_size = std::min(dim_ceil, 32);
  const int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;

  constexpr int threads_per_block = 128;
  const int warps_per_block = threads_per_block / warp_size;
  const int batches_per_block = warps_per_block * batches_per_warp;
  int blocks = (N + batches_per_block - 1) / batches_per_block;
  dim3 threads(warp_size, warps_per_block, 1);

  // Vectorized read/write types.
  using T4 = typename VecT4<T>::Type;
  using T2 = typename VecT2<T>::Type;
  const T* dout_data = dout.data<T>();
  const T* out_data = out.data<T>();

  if (dim % 2 != 0) {
    // Odd rows cannot be vectorized.
    SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
                                             threads,
                                             dev_ctx,
                                             dx_data,
                                             dout_data,
                                             out_data,
                                             N,
                                             dim,
                                             dim,
                                             dim_log2);
  } else if (dim % 4 == 0) {
    SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
                                              threads,
                                              dev_ctx,
                                              dx_data,
                                              dout_data,
                                              out_data,
                                              N,
                                              dim,
                                              dim,
                                              dim_log2);
  } else {
    SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
                                              threads,
                                              dev_ctx,
                                              dx_data,
                                              dout_data,
                                              out_data,
                                              N,
                                              dim,
                                              dim,
                                              dim_log2);
  }
}
C
carryyu 已提交
1453 1454 1455 1456
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
#undef FIXED_VEC_SIZE_BASE
#undef FIXED_VEC_SIZE
1457

1458
}  // namespace phi