softmax_gpudnn.h 50.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include "paddle/phi/backends/gpu/gpu_info.h"
18 19
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
20 21
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
22
#include "paddle/phi/kernels/funcs/aligned_vector.h"
23 24 25 26
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

// See Note [ Why still include the fluid headers? ]
27 28
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
29

30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
// Alignment (in bytes) that KeMatrixSoftmaxForward uses to compute how many
// elements a row pointer sits past the previous aligned address.
#define MATRIX_SOFTMAX_ALIGN_BYTES 16
// (sic: "THREAHOLD") Not referenced in this part of the file; presumably a
// row-length threshold for selecting the matrix softmax path -- TODO confirm
// against its users before relying on this meaning.
#define MATRIX_SOFTMAX_THREAHOLD 100000

// One `case` label of a switch on a block dimension: binds the compile-time
// constant kBlockDim to `dim` and expands the remaining arguments as the body.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// One `case` label of a switch on a vector width: binds the compile-time
// constant VecSize to `vec_size` and expands the remaining arguments.
#define FIXED_VEC_SIZE_BASE(vec_size, ...) \
  case (vec_size): {                       \
    constexpr auto VecSize = (vec_size);   \
    __VA_ARGS__;                           \
  } break

// Expands to case labels for block dims 512, 256, 128, 64 and 32.
#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

// Expands to case labels for vector widths 8 and 4.
#define FIXED_VEC_SIZE(...)              \
  FIXED_VEC_SIZE_BASE(8, ##__VA_ARGS__); \
  FIXED_VEC_SIZE_BASE(4, ##__VA_ARGS__)

56
namespace phi {
57

58 59
// Local aliases so the rest of this file can name these paddle::platform
// types unqualified.
using ScopedTensorDescriptor = paddle::platform::ScopedTensorDescriptor;
using GPUDNNDataLayout = paddle::platform::DataLayout;

// Vectorization trait 4 * sizeof(T): VecT4<T>::Type is a builtin vector type
// whose size equals four elements of T (double -> long4, float -> int4,
// float16 / bfloat16 -> int2). The primary template defines no Type, so
// naming VecT4<U>::Type for any other U fails to compile.
template <typename T>
class VecT4 {};
template <>
class VecT4<double> {
 public:
  using Type = long4;
};
template <>
class VecT4<float> {
 public:
  using Type = int4;
};
template <>
class VecT4<phi::dtype::float16> {
 public:
  using Type = int2;
};
template <>
class VecT4<phi::dtype::bfloat16> {
 public:
  using Type = int2;
};

// Vectorization trait 2 * sizeof(T): VecT2<T>::Type is a builtin type whose
// size equals two elements of T (double -> int4, float -> int2,
// float16 / bfloat16 -> int). The primary template defines no Type.
template <typename T>
class VecT2 {};
template <>
class VecT2<double> {
 public:
  using Type = int4;
};
template <>
class VecT2<float> {
 public:
  using Type = int2;
};
template <>
class VecT2<phi::dtype::float16> {
 public:
  using Type = int;
};
template <>
class VecT2<phi::dtype::bfloat16> {
 public:
  using Type = int;
};

// Returns the smallest k >= 0 with 2^k >= value (0 for value <= 1).
// The power of two is formed in 64 bits: with a 32-bit `1 << k`, inputs
// above 2^30 would shift into the sign bit (undefined behavior, and in
// practice a potential infinite loop).
static inline int Log2Ceil(int value) {
  int log2_value = 0;
  while ((static_cast<int64_t>(1) << log2_value) < value) ++log2_value;
  return log2_value;
}

// Picks a power-of-two thread count for the matrix softmax kernel.
// The target is min(dim_size / vec_size, 1024), halved when vec_size > 1.
// The result is the smallest power of two >= that target, but never fewer
// than 32 threads.
inline int getBlockSize(int vec_size, uint64_t dim_size) {
  uint64_t limit = std::min(dim_size / vec_size, static_cast<uint64_t>(1024));
  if (vec_size > 1) limit /= 2;

  uint64_t threads = 1;
  while (threads < limit) {
    threads <<= 1;
  }
  return static_cast<int>(std::max(threads, static_cast<uint64_t>(32)));
}

129 130 131 132 133 134
// XOR-butterfly sum within aligned groups of WarpSize lanes, applied
// independently to each of the BatchSize entries of sum[]. Offsets run
// WarpSize/2, ..., 1, so afterwards every lane of a group holds the group
// total. The full 0xFFFFFFFF mask assumes every lane of the hardware warp
// reaches this call. (CudaShuffleXorSync presumably wraps __shfl_xor_sync.)
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T sum_val =
          paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = sum[i] + sum_val;
    }
  }
}

// XOR-butterfly maximum within aligned groups of WarpSize lanes, applied
// independently to each of the BatchSize entries of sum[] (the parameter
// name is historical: it holds running maxima here). Afterwards every lane
// of a group holds the group maximum. The full 0xFFFFFFFF mask assumes every
// lane of the hardware warp reaches this call.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceMax(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T max_val =
          paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = max(sum[i], max_val);
    }
  }
}

// Block-wide maximum of *val for a 1-D block (indexes by threadIdx.x only;
// at most 32 warps fit the 32-slot buffer).
// Steps:
//   1. Each warp reduces its lanes.
//   2. Lane 0 of each warp publishes its result to shared memory.
//   3. Every warp re-reduces the per-warp partials, so all threads end with
//      the block maximum in *val.
template <typename T>
__inline__ __device__ void BlockReduceMax(T* val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  WarpReduceMax<T, 1, 32>(val);

  // Barrier before overwriting shared[]: an earlier call in the same kernel
  // may still be reading it (same pattern as BlockReduceSum).
  __syncthreads();
  if (lane == 0) shared[wid] = *val;

  __syncthreads();

  // Pad slots past the last warp with the identity of max (-inf).
  // A finite sentinel such as -1e10 would replace the true maximum of rows
  // whose values all lie below it (including all -inf rows).
  int block_span = (blockDim.x + warpSize - 1) >> 5;
  *val = (lane < block_span) ? shared[lane]
                             : -std::numeric_limits<T>::infinity();
  WarpReduceMax<T, 1, 32>(val);
}

// Block-wide sum of *val for a 1-D block (indexes by threadIdx.x only).
// Per-warp totals are staged in a 32-slot shared buffer. After the final
// warp reduction, every thread holds the block total in *val.
template <typename T>
__inline__ __device__ void BlockReduceSum(T* val) {
  static __shared__ T warp_partials[32];
  const int lane_id = threadIdx.x % 32;
  const int warp_id = threadIdx.x / 32;

  WarpReduceSum<T, 1, 32>(val);

  // Barrier before the write so no thread is still reading warp_partials.
  __syncthreads();
  if (lane_id == 0) {
    warp_partials[warp_id] = *val;
  }
  __syncthreads();

  const int num_warps = (blockDim.x + warpSize - 1) >> 5;
  T partial = static_cast<T>(0.0f);
  if (lane_id < num_warps) {
    partial = warp_partials[lane_id];
  }
  *val = partial;
  WarpReduceSum<T, 1, 32>(val);
}

190 191 192 193 194 195 196 197 198
// Binary max reducer whose identity element is -infinity. It is passed to
// kps::Reduce in WarpSoftmaxForward.
template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor {
  inline Ty initial() {
    Ty identity = -std::numeric_limits<Ty>::infinity();
    return identity;
  }

  __device__ __forceinline__ Ty operator()(const Ty& lhs, const Ty& rhs) const {
    return max(lhs, rhs);
  }
};

199 200 201 202 203 204 205 206
// Folds one element of type T into a running maximum kept in AccT.
template <typename T, typename AccT>
struct MaxFunctor {
  __device__ __forceinline__ AccT operator()(const AccT& max_v,
                                             const T& v) const {
    const AccT candidate = static_cast<AccT>(v);
    return max(max_v, candidate);
  }
};

// Elementwise exponential: y = exp(x), evaluated in Tx and cast to Ty.
template <typename Tx, typename Ty = Tx>
struct ExpFunctor {
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x));
  }
};

// Elementwise y = exp(x) * scale. A default-constructed functor uses scale 1.
template <typename Tx, typename Ty = Tx>
struct ExpMulFunctor {
  HOSTDEVICE inline ExpMulFunctor() : scale_(static_cast<Tx>(1.0f)) {}

  HOSTDEVICE explicit inline ExpMulFunctor(Tx y)
      : scale_(static_cast<Tx>(y)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto product = std::exp(x) * scale_;
    return static_cast<Ty>(product);
  }

 private:
  Tx scale_;
};

// Elementwise y = x - offset. A default-constructed functor subtracts 0.
template <typename Tx, typename Ty = Tx>
struct UnarySubFunctor {
  HOSTDEVICE inline UnarySubFunctor() : offset_(static_cast<Tx>(0.0f)) {}

  HOSTDEVICE explicit inline UnarySubFunctor(Tx y)
      : offset_(static_cast<Tx>(y)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto difference = x - offset_;
    return static_cast<Ty>(difference);
  }

 private:
  Tx offset_;
};

// Elementwise natural logarithm: y = log(x), cast to Ty.
template <typename Tx, typename Ty = Tx>
struct UnaryLogFunctor {
  HOSTDEVICE inline UnaryLogFunctor() {}

  // The int argument is ignored.
  HOSTDEVICE explicit inline UnaryLogFunctor(int /*n*/) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto log_x = std::log(x);
    return static_cast<Ty>(log_x);
  }
};

// Converts Tx to Ty. An input equal to -infinity of Tx is mapped explicitly
// to -infinity of Ty; every other value goes through static_cast.
template <typename Tx, typename Ty>
struct DataTransFunctor {
  HOSTDEVICE inline DataTransFunctor() {}

  // The int argument is ignored.
  HOSTDEVICE explicit inline DataTransFunctor(int /*n*/) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    if (x == -std::numeric_limits<Tx>::infinity()) {
      return -std::numeric_limits<Ty>::infinity();
    }
    return static_cast<Ty>(x);
  }
};

// Elementwise y = x / n, implemented as a multiply by a reciprocal computed
// once at construction. The reciprocal is formed as 1.0 / n in double
// precision and then narrowed to Tx. A default-constructed functor divides
// by 1.
template <typename Tx, typename Ty = Tx>
struct UnaryDivFunctor {
  HOSTDEVICE inline UnaryDivFunctor() : inv_divisor_(static_cast<Tx>(1.0f)) {}

  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n)
      : inv_divisor_(static_cast<Tx>(1.0 / n)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto scaled = x * inv_divisor_;
    return static_cast<Ty>(scaled);
  }

 private:
  Tx inv_divisor_;
};

280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
// One softmax output: y = exp(x - max) / sum, where `max` and `sum` are
// supplied by the caller (the row maximum and the row sum of exp(x - max)).
template <typename Tx, typename Ty = Tx>
struct SoftmaxForwardFunctor {
  HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
      : row_max_(max), row_sum_(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto shifted_exp = std::exp(x - row_max_);
    return static_cast<Ty>(shifted_exp / row_sum_);
  }

 private:
  Tx row_max_;
  Tx row_sum_;
};

// Softmax input-gradient element: y = out * (grad_out - sum), where `sum` is
// a caller-supplied per-row reduction.
template <typename Tx, typename Ty = Tx>
struct SoftmaxBackwardFunctor {
  HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : row_dot_(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    const auto centered_grad = grad_out - row_dot_;
    return static_cast<Ty>(out * centered_grad);
  }

 private:
  Tx row_dot_;
};

// One log-softmax output: y = x - max - log(sum).
// NOTE: the constructor takes the *raw* sum of exp(x - max) and applies
// log itself, so callers must not pass log(sum).
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxForwardFunctor {
  HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
      : row_max_(max), log_row_sum_(std::log(sum)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto centered = x - row_max_;
    return static_cast<Ty>(centered - log_row_sum_);
  }

 private:
  Tx row_max_;
  Tx log_row_sum_;
};

// Log-softmax input-gradient element: y = grad_out - exp(out) * sum, where
// `out` is the forward log-softmax output and `sum` a caller-supplied per-row
// reduction.
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxBackwardFunctor {
  HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : grad_sum_(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    const auto weighted = std::exp(out) * grad_sum_;
    return static_cast<Ty>(grad_out - weighted);
  }

 private:
  Tx grad_sum_;
};

332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
// Accumulator step sum + exp(v - max_v), computed in AccT. max_v is fixed at
// construction (the caller passes the row maximum).
template <typename T, typename AccT>
struct SumExpFunctor {
  HOSTDEVICE inline SumExpFunctor(AccT v) : row_max_(v) {}

  HOSTDEVICE inline AccT operator()(AccT sum, T v) const {
    const AccT shifted = static_cast<AccT>(v) - row_max_;
    return sum + std::exp(shifted);
  }

 private:
  AccT row_max_;
};

// Computes one thread's partial reduction of the row data[0, dim_size).
// `functor` folds a T element into an AccT accumulator that starts at
// `default_value`. Threads stride through the row by blockDim.x; the caller
// combines the per-thread results across the block.
//
// `shift` is the number of elements between the previous
// MATRIX_SOFTMAX_ALIGN_BYTES boundary and `data` (computed by the caller).
// The row is processed in three parts:
//   1. If shift > 0, the first blockDim.x-wide window starting at that
//      aligned address is read with scalar loads, skipping the `shift`
//      slots that precede the row.
//   2. The main loop uses VecSize-wide loads.
//   3. A final scalar loop handles the remainder that does not fill a whole
//      VecSize * blockDim.x chunk.
// NOTE(review): the vector loads assume VecSize * sizeof(T) alignment is
// satisfied by that aligned base.
template <template <typename, typename> class Reduction,
          typename T,
          typename AccT,
          int VecSize>
__device__ __forceinline__ AccT
ThreadVecReduce(T* data,
                int dim_size,
                const int shift,
                const Reduction<T, AccT>& functor,
                AccT default_value) {
  using VecT = phi::AlignedVector<T, VecSize>;
  AccT thread_val = default_value;

  // for memory align, handle the unaligned data in first block.
  int offset = threadIdx.x;
  if (shift > 0) {
    data -= shift;
    dim_size += shift;
    // Upper bound: for rows shorter than one block, threads past the row end
    // would otherwise read neighboring memory into the reduction.
    if (offset >= shift && offset < dim_size) {
      thread_val = functor(thread_val, data[offset]);
    }
    dim_size -= blockDim.x;
    data += blockDim.x;
    // The whole row fit in the head window. A negative dim_size would also
    // make the tail loop below index negative offsets.
    if (dim_size <= 0) return thread_val;
  }

  const int last = dim_size % (VecSize * blockDim.x);

  T v[VecSize];
  VecT* value = reinterpret_cast<VecT*>(&v);

  for (; offset * VecSize < dim_size - last; offset += blockDim.x) {
    *value = reinterpret_cast<VecT*>(data)[offset];
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      thread_val = functor(thread_val, v[i]);
    }
  }

  // the tail that does not fill a whole vectorized chunk
  offset = dim_size - last + threadIdx.x;
  for (; offset < dim_size; offset += blockDim.x) {
    thread_val = functor(thread_val, data[offset]);
  }
  return thread_val;
}

// Writes out[j] = functor(AccT(input[j])) for j in [0, dim_size), striding
// threads by blockDim.x. It uses VecSize-wide loads and stores after an
// alignment head window and before a scalar tail. The caller only uses this
// when `input` and `out` share the same misalignment `shift` (elements past
// the previous MATRIX_SOFTMAX_ALIGN_BYTES boundary), so one shift aligns both.
template <template <typename, typename> class Reduction,
          typename T,
          typename AccT,
          int VecSize>
__device__ __forceinline__ void ThreadVecWriteVec(T* out,
                                                  T* input,
                                                  int dim_size,
                                                  const int shift,
                                                  Reduction<AccT, T> functor) {
  using VecT = phi::AlignedVector<T, VecSize>;

  // for memory align, handle the unaligned data in first block.
  int offset = threadIdx.x;
  if (shift > 0) {
    input -= shift;
    out -= shift;
    dim_size += shift;
    // Upper bound: for rows shorter than one block, threads past the row end
    // would otherwise write outside this row.
    if (offset >= shift && offset < dim_size) {
      out[offset] = functor(static_cast<AccT>(input[offset]));
    }
    dim_size -= blockDim.x;
    input += blockDim.x;
    out += blockDim.x;
    // The whole row fit in the head window; nothing left to write.
    if (dim_size <= 0) return;
  }

  const int last = dim_size % (VecSize * blockDim.x);

  T in_v[VecSize];
  VecT* in_value = reinterpret_cast<VecT*>(&in_v);

  T out_v[VecSize];
  VecT* out_value = reinterpret_cast<VecT*>(&out_v);

  for (; offset * VecSize < dim_size - last; offset += blockDim.x) {
    *in_value = reinterpret_cast<VecT*>(input)[offset];
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      out_v[i] = functor(static_cast<AccT>(in_v[i]));
    }
    reinterpret_cast<VecT*>(out)[offset] = *out_value;
  }

  offset = dim_size - last + threadIdx.x;
  // the tail
  for (; offset < dim_size; offset += blockDim.x) {
    out[offset] = functor(static_cast<AccT>(input[offset]));
  }
}

// Scalar fallback used when `input` and `out` are misaligned differently.
// It writes out[j] = functor(AccT(input[j])) for j in [0, dim_size).
// Each outer iteration covers a chunk of VecSize * blockDim.x elements, in
// which a thread handles offset + i * blockDim.x for i in [0, VecSize).
// The remainder is handled by a final strided loop.
template <template <typename, typename> class Reduction,
          typename T,
          typename AccT,
          int VecSize>
__device__ __forceinline__ void ThreadVecWrite(T* out,
                                               T* input,
                                               int dim_size,
                                               Reduction<AccT, T> functor) {
  const int last = dim_size % (VecSize * blockDim.x);

  for (int offset = threadIdx.x; offset < dim_size - last;
       offset += blockDim.x * VecSize) {
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      out[offset + i * blockDim.x] =
          functor(static_cast<AccT>(input[offset + i * blockDim.x]));
    }
  }

  // the tail
  for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
       offset += blockDim.x) {
    out[offset] = functor(static_cast<AccT>(input[offset]));
  }
}

// Softmax (or log-softmax when LogMode) over one row per block. Row index is
// blockIdx.x and each row has dim_size elements. BlockReduceMax/Sum index by
// threadIdx.x only, so a 1-D block is assumed.
// The kernel makes three passes over the row:
//   1. Reduce the row maximum.
//   2. Reduce the row sum of exp(x - max).
//   3. Write the normalized output, vectorized when input and output share
//      the same misalignment relative to MATRIX_SOFTMAX_ALIGN_BYTES.
// IndexType and BatchSize are part of the template interface but unused here.
template <typename T,
          typename AccT,
          typename IndexType,
          int BatchSize,
          int VecSize,
          bool LogMode = false>
__global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
  // Row offset in 64-bit: rows * dim_size can exceed INT_MAX for big inputs.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * dim_size;
  T* batch_input = const_cast<T*>(src) + row_offset;
  T* batch_output = softmax + row_offset;

  const int input_align_shift =
      ((uint64_t)batch_input) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
  const int output_align_shift =
      ((uint64_t)batch_output) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);

  // get max value.
  // Seed with -inf, the identity of max. numeric_limits<AccT>::min() is the
  // smallest *positive* normal value: it would clamp the maximum of an
  // all-negative row to ~0, and exp(x - max) then underflows (sum 0 -> NaN)
  // for very negative rows.
  AccT thread_max = ThreadVecReduce<MaxFunctor, T, AccT, VecSize>(
      batch_input,
      dim_size,
      input_align_shift,
      MaxFunctor<T, AccT>(),
      -std::numeric_limits<AccT>::infinity());
  BlockReduceMax<AccT>(&thread_max);

  // get exp value and sum all
  AccT thread_exp = ThreadVecReduce<SumExpFunctor, T, AccT, VecSize>(
      batch_input,
      dim_size,
      input_align_shift,
      SumExpFunctor<T, AccT>(thread_max),
      static_cast<AccT>(0.));
  BlockReduceSum<AccT>(&thread_exp);

  // write data to softmax_output according to the LogMode
  if (LogMode) {
    // LogSoftmaxForwardFunctor applies log to its `sum` argument itself;
    // passing std::log(thread_exp) here would compute log(log(sum)).
    LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
    if (input_align_shift == output_align_shift) {
      ThreadVecWriteVec<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, input_align_shift, reduction);
    } else {
      ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, reduction);
    }
  } else {
    SoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
    if (input_align_shift == output_align_shift) {
      ThreadVecWriteVec<SoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, input_align_shift, reduction);
    } else {
      ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
          batch_output, batch_input, dim_size, reduction);
    }
  }
}

/*
Core function of computing softmax forward for axis=-1.
The computation includes
  - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j}
  - Compute sum of exp batch: s_{i} = sum_{j} exp(src_{i,j} - maxvalue_{i})
  - Compute the output:
      softmax:     dst_{i,j} = exp(src_{i,j} - maxvalue_{i}) / s_{i}
      log_softmax: dst_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i})
One logical warp of kWarpSize = min(2^Log2Elements, 32) threads computes
kBatchSize rows: 2 rows when 2^Log2Elements <= 32, otherwise 1.
For reduction max (sum), each thread first reduces its own elements, then
shuffles combine the partial results across the warp.
*/
template <typename T,
          typename VecT,
          typename AccT,
          typename IndexType,
          int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax,
                                   const T* src,
                                   const IndexType batch_size,
                                   const IndexType stride,
                                   const IndexType element_count) {
  constexpr IndexType kDimCeil = 1 << Log2Elements;
  constexpr IndexType kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr IndexType kVSize = sizeof(VecT) / sizeof(T);
  constexpr IndexType kLoops = kDimCeil / kWarpSize;
  constexpr IndexType kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  constexpr IndexType kBatchSize = (kDimCeil <= 32) ? 2 : 1;
  IndexType first_batch =
      (static_cast<IndexType>(blockDim.y) * blockIdx.x + threadIdx.y) *
      kBatchSize;
  constexpr IndexType kStep = kBatchSize * kLoopsV * kVSize;
  constexpr IndexType kVItem = kLoopsV * kVSize;
  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
  using kMode = kps::details::ReduceMode;

  // max index to read (0 for rows past batch_size, so nothing is read or
  // written for them)
  IndexType idx_max_v[kBatchSize];
#pragma unroll
  for (IndexType i = 0; i < kBatchSize; i++) {
    IndexType idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // data src
  // src_data: the raw data from global memory
  // sub_data: store the data obtained by (src_data - max), used by log_softmax
  // exp_data: store the data obtained by (exp(sub_data)), used by softmax
  T src_data[kBatchSize][kLoopsV][kVSize];
  AccT sub_data[kBatchSize][kLoopsV][kVSize];
  AccT exp_data[kBatchSize][kLoopsV][kVSize];
  kps::Init<AccT, kStep>(&sub_data[0][0][0], kLowInf);
  kps::Init<T, kStep>(&src_data[0][0][0], -std::numeric_limits<T>::infinity());

  // data dst
  T out_tmp[kBatchSize][kLoopsV][kVSize];

  // max value
  AccT max[kBatchSize];
  kps::Init<AccT, kBatchSize>(&max[0], kLowInf);

  // sum value
  AccT sum[kBatchSize] = {0};

// read data from global memory
#pragma unroll
  for (IndexType i = 0; i < kBatchSize; ++i) {
    const VecT* src_v =
        reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
    VecT* reg_v = reinterpret_cast<VecT*>(&src_data[i][0][0]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, true>(
        &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
    kps::ElementwiseUnary<T, AccT, kVItem, 1, DataTransFunctor<T, AccT>>(
        &sub_data[i][0][0], &src_data[i][0][0], DataTransFunctor<T, AccT>());
  }

  // compute max
  kps::Reduce<AccT,
              kVItem,
              kBatchSize,
              ReduceMaxFunctor<AccT>,
              kMode::kLocalMode>(
      &max[0], &sub_data[0][0][0], ReduceMaxFunctor<AccT>(), true);
  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);

// compute sum
#pragma unroll
  for (IndexType i = 0; i < kBatchSize; ++i) {
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, UnarySubFunctor<AccT>>(
        &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor<AccT>(max[i]));
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, ExpFunctor<AccT>>(
        &exp_data[i][0][0], &sub_data[i][0][0], ExpFunctor<AccT>());
  }
  kps::Reduce<AccT,
              kVItem,
              kBatchSize,
              kps::AddFunctor<AccT>,
              kMode::kLocalMode>(
      &sum[0], &exp_data[0][0][0], kps::AddFunctor<AccT>(), true);
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

// write data to global memory
#pragma unroll
  for (IndexType i = 0; i < kBatchSize; ++i) {
    VecT* softmax_v =
        reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    if (LogMode) {
      kps::ElementwiseUnary<AccT, T, kVItem, 1, UnarySubFunctor<AccT>>(
          &out_tmp[i][0][0],
          &sub_data[i][0][0],
          UnarySubFunctor<AccT>(std::log(sum[i])));
    } else {
      kps::ElementwiseUnary<AccT, T, kVItem, 1, UnaryDivFunctor<AccT>>(
          &out_tmp[i][0][0], &exp_data[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
    }
    kps::WriteData<VecT, VecT, kLoopsV, 1, true>(
        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
}

/*
Core function of computing softmax backward for axis=-1.
src is the forward output and grad the gradient w.r.t. it.
The computation includes
  - Compute the per-row reduction:
      softmax:     s_{i} = sum_{j} src_{i,j} * grad_{i,j}
      log_softmax: s_{i} = sum_{j} grad_{i,j}
  - Compute the input gradient:
      softmax:     dst_{i,j} = src_{i,j} * (grad_{i,j} - s_{i})
      log_softmax: dst_{i,j} = grad_{i,j} - exp(src_{i,j}) * s_{i}
One logical warp of kWarpSize = min(2^Log2Elements, 32) threads computes
kBatchSize rows: 2 rows when 2^Log2Elements <= 128, otherwise 1.
Each thread first sums its own elements, then shuffles combine the partial
sums across the warp.
*/
template <typename T,
          typename VecT,
          typename AccT,
          int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst,
                                    const T* grad,
                                    const T* src,
                                    int batch_size,
                                    int stride,
                                    int element_count) {
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kLoops = kDimCeil / kWarpSize;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  int local_batches = min(batch_size - first_batch, kBatchSize);

  // max index to read (0 for rows past batch_size)
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // read data from global memory
  VecT src_reg[kBatchSize][kLoopsV];
  VecT grad_reg[kBatchSize][kLoopsV];
  VecT k_value;
  for (int s = 0; s < kVSize; s++) {
    reinterpret_cast<T*>(&k_value)[s] = 0.0;
  }
  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    int flag = i < local_batches ? 1 : 0;
    // 64-bit row offset: (row * stride) can exceed INT_MAX for large inputs.
    const int64_t ptr = static_cast<int64_t>(first_batch + i) * stride;
    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, true>(
        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
    kps::ReadData<VecT, VecT, kLoopsV, 1, true>(
        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
  }

  // change T to AccT
  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
  constexpr int kVItem = kLoopsV * kVSize;
  kps::ElementwiseUnary<T, AccT, kStep, 1, DataTransFunctor<T, AccT>>(
      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
  kps::ElementwiseUnary<T, AccT, kStep, 1, DataTransFunctor<T, AccT>>(
      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());

  // compute sum
  AccT sum[kBatchSize]{0.0};
  AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
  AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
  AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
  if (LogMode) {
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &grad_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  } else {
    kps::ElementwiseBinary<AccT, AccT, kStep, 1, kps::MulFunctor<AccT>>(
        &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // write result to global memory
  AccT out[kBatchSize][kLoopsV][kVSize];
  T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;
    AccT* grad_row = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
    AccT* src_row = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
    if (LogMode) {
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, ExpMulFunctor<AccT>>(
          &out[i][0][0], &src_row[0], ExpMulFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, kps::SubFunctor<AccT>>(
          &out_tmp[i][0][0],
          &grad_row[0],
          &out[i][0][0],
          kps::SubFunctor<AccT>());
    } else {
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, UnarySubFunctor<AccT>>(
          &out[i][0][0], &grad_row[0], UnarySubFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, kps::MulFunctor<AccT>>(
          &out_tmp[i][0][0],
          &src_row[0],
          &out[i][0][0],
          kps::MulFunctor<AccT>());
    }
    VecT* dst_v = reinterpret_cast<VecT*>(
        &dst[static_cast<int64_t>(first_batch + i) * stride]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    kps::WriteData<VecT, VecT, kLoopsV, 1, true>(
        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
}

#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT)                   \
  case Log2Elements:                                                    \
    WarpSoftmaxForward<T, VecT, AccT, IndexType, Log2Elements, LogMode> \
        <<<blocks, threads, 0, dev_ctx.stream()>>>(                     \
            dst, src, batch_size, stride, element_count);               \
    break;

/*
  Wrapper of softmax forward with template instantiation on size of input.
  Log2Elements selects the WarpSoftmaxForward instantiation for rows of up to
  2^Log2Elements elements. Only 0..9 are instantiated; any other value falls
  through to `default` and launches nothing (NOTE(review): callers must not
  route larger rows here).
  The kernel runs on dev_ctx.stream() with AccT = MPTypeTrait<T>::Type.
*/
template <typename T, typename VecT, typename IndexType, bool LogMode>
void SwitchWarpSoftmaxForward(const IndexType blocks,
                              const dim3 threads,
                              const GPUContext& dev_ctx,
                              T* dst,
                              const T* src,
                              const IndexType batch_size,
                              const IndexType stride,
                              const IndexType element_count,
                              IndexType Log2Elements) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_FORWARD_CASE(0, AccT);
    SOFTMAX_WARP_FORWARD_CASE(1, AccT);
    SOFTMAX_WARP_FORWARD_CASE(2, AccT);
    SOFTMAX_WARP_FORWARD_CASE(3, AccT);
    SOFTMAX_WARP_FORWARD_CASE(4, AccT);
    SOFTMAX_WARP_FORWARD_CASE(5, AccT);
    SOFTMAX_WARP_FORWARD_CASE(6, AccT);
    SOFTMAX_WARP_FORWARD_CASE(7, AccT);
    SOFTMAX_WARP_FORWARD_CASE(8, AccT);
    SOFTMAX_WARP_FORWARD_CASE(9, AccT);
    default:
      break;
  }
}

#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT)          \
  case Log2Elements:                                            \
    WarpSoftmaxBackward<T, VecT, AccT, Log2Elements, LogMode>   \
        <<<blocks, threads, 0, dev_ctx.stream()>>>(             \
            dst, grad, src, batch_size, stride, element_count); \
    break;

/*
Wrapper of softmax backward with template instantiation on size of input.
Log2Elements selects the WarpSoftmaxBackward instantiation for rows of up to
2^Log2Elements elements. Only 0..9 are instantiated; any other value launches
nothing (NOTE(review): callers must not route larger rows here).
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks,
                               const dim3 threads,
                               const GPUContext& dev_ctx,
                               T* dst,
                               const T* grad,
                               const T* src,
                               const int batch_size,
                               const int stride,
                               const int element_count,
                               int Log2Elements) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(2, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(3, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(4, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(5, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(6, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
    default:
      break;
  }
}

#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE

/**
 * <NormalSoftmaxKernel>
 * Better performance when axis != -1
 */

// Chooses a 2-D grid for the [high_dim, mid_dim, low_dim] kernels.
//   - grid.x tiles low_dim with block.x threads per block.
//   - grid.y walks high_dim.
// The total is kept near the number of blocks whose threads fit the device's
// resident capacity (SM count * max threads per SM).
// mid_dim is not used here.
static void GetGridDim(
    int high_dim, int mid_dim, int low_dim, const dim3& block, dim3* grid) {
  int device_id = phi::backends::gpu::GetCurrentDeviceId();
  int max_mp = phi::backends::gpu::GetGPUMultiProcessors(device_id);
  int max_threads_per_mp =
      phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
  int max_threads = max_threads_per_mp * max_mp;
  int num_threads = block.x * block.y;
  // Always allow at least one block.
  int max_num_blocks = std::max(max_threads / num_threads, 1);

  int grid_x = (low_dim + block.x - 1) / block.x;
  // Clamp to >= 1: a zero grid dimension is an invalid launch, and grid_x is
  // also the divisor on the next line.
  grid_x = std::max(std::min(grid_x, max_num_blocks), 1);
  int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
  grid_y = std::max(std::min(grid_y, high_dim), 1);
  grid->x = grid_x;
  grid->y = grid_y;
}

static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
  // Thread-block shape for the Normal* softmax kernels. x spans low_dim and
  // y spans mid_dim, each rounded up to a power of two, with
  // x * y <= kThreadLimit (256 on HIP, 1024 otherwise).
  // y is sized while reserving room for at most 32 threads in x. x is then
  // widened into whatever budget y left over.
#ifdef __HIPCC__
  constexpr int kThreadLimit = 256;
#else
  constexpr int kThreadLimit = 1024;
#endif
  const int low_pow2 = 1 << Log2Ceil(low_dim);
  const int mid_pow2 = 1 << Log2Ceil(mid_dim);

  const int reserved_x = std::min(low_pow2, 32);
  const int threads_y = std::min(mid_pow2, kThreadLimit / reserved_x);
  const int threads_x = std::min(low_pow2, kThreadLimit / threads_y);

  block->y = threads_y;
  block->x = threads_x;
}

884 885
static void GetLaunchConfig(
    int high_dim, int mid_dim, int low_dim, dim3* grid, dim3* block) {
  // Launch shape for NormalSoftmaxForward / NormalSoftmaxBackward. The block
  // shape is chosen first because the grid is sized from it.
  GetBlockDim(mid_dim, low_dim, block);
  const dim3& block_shape = *block;
  GetGridDim(high_dim, mid_dim, low_dim, block_shape, grid);
}

890 891
/**
 * (Log)softmax forward over the middle axis of `input`, addressed as a
 * row-major [high_dim, mid_dim, low_dim] array (low_dim contiguous):
 *   element(h, m, l) = input[h * mid_dim * low_dim + m * low_dim + l].
 *
 * Work split (launch shape comes from GetLaunchConfig):
 *   - high_id starts at blockIdx.y and steps by gridDim.y.
 *   - low_id starts at blockIdx.x * blockDim.x + threadIdx.x and steps by
 *     blockDim.x * gridDim.x.
 *   - Threads of a block that share threadIdx.x (hence low_id) stride over
 *     mid_dim with threadIdx.y. They combine their partial max / sum with
 *     kps::Reduce (kGlobalMode) when blockDim.y > 1.
 * Functor<AccT, T>(max, sum) then maps each input value (in AccT) to the
 * output element.
 *
 * Offsets are computed in int64_t. SoftmaxForwardCUDAKernelDriver also routes
 * tensors with >= INT32_MAX elements here, and there
 * high_id * mid_dim * low_dim overflows int even when each extent fits.
 */
template <typename T,
          typename AccT,
          template <typename, typename>
          class Functor>
__global__ void NormalSoftmaxForward(
    T* output, const T* input, int high_dim, int mid_dim, int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int64_t high_stride = static_cast<int64_t>(mid_dim) * low_dim;
  const int64_t mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      const int64_t input_offset = high_id * high_stride + low_id;

      // 1. reduce max
      AccT max_value = -std::numeric_limits<AccT>::infinity();
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const AccT value =
            static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        max_value = kps::MaxFunctor<AccT>()(max_value, value);
      }

      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, kps::MaxFunctor<AccT>, kMode::kGlobalMode>(
            &max_value, &max_value, kps::MaxFunctor<AccT>(), false);
      }

      // 2. reduce sum of exp(x - max)
      AccT sum = 0;
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const AccT value =
            static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        sum += std::exp(value - max_value);
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 3. (log)softmax
      Functor<AccT, T> functor(max_value, sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const int64_t data_offset = input_offset + mid_id * mid_stride;
        output[data_offset] = functor(static_cast<AccT>(input[data_offset]));
      }
    }
  }
}

938 939
/**
 * (Log)softmax backward over the middle axis, with the same
 * [high_dim, mid_dim, low_dim] addressing and thread work split as
 * NormalSoftmaxForward.
 *
 * For each (high_id, low_id) column it first reduces
 *   sum(dy)      when LogMode is true, or
 *   sum(dy * y)  otherwise
 * over mid_dim. The reduction is combined across blockDim.y with kps::Reduce
 * when blockDim.y > 1. It then writes
 *   input_grad = Functor<AccT, T>(sum)(dy, y)
 * for each element of the column.
 *
 * Offsets are int64_t so that high_id * mid_dim * low_dim cannot overflow int
 * on large tensors.
 */
template <typename T,
          typename AccT,
          template <typename, typename>
          class Functor,
          bool LogMode>
__global__ void NormalSoftmaxBackward(T* input_grad,
                                      const T* output_grad,
                                      const T* output,
                                      int high_dim,
                                      int mid_dim,
                                      int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int64_t high_stride = static_cast<int64_t>(mid_dim) * low_dim;
  const int64_t mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      const int64_t grad_offset = high_id * high_stride + low_id;

      // 1. reduce sum
      AccT sum = 0;
      if (LogMode) {
        for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
          const int64_t data_offset = grad_offset + mid_id * mid_stride;
          sum += static_cast<AccT>(output_grad[data_offset]);
        }
      } else {
        for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
          const int64_t data_offset = grad_offset + mid_id * mid_stride;
          sum += static_cast<AccT>(output_grad[data_offset]) *
                 static_cast<AccT>(output[data_offset]);
        }
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 2. (log)softmax backward
      Functor<AccT, T> functor(sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const int64_t data_offset = grad_offset + mid_id * mid_stride;
        input_grad[data_offset] =
            functor(static_cast<AccT>(output_grad[data_offset]),
                    static_cast<AccT>(output[data_offset]));
      }
    }
  }
}

988
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const GPUContext& dev_ctx,
                                T* output_data,
                                const T* input_data,
                                int high_dim,
                                int mid_dim,
                                int low_dim) {
  // Host-side launcher for NormalSoftmaxForward on dev_ctx's stream.
  // LogMode selects LogSoftmaxForwardFunctor; otherwise SoftmaxForwardFunctor
  // is used. Accumulation happens in the MPTypeTrait type of T.
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  dim3 grid_dim;
  dim3 block_dim;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid_dim, &block_dim);
  auto stream = dev_ctx.stream();
  if (!LogMode) {
    NormalSoftmaxForward<T, AccT, SoftmaxForwardFunctor>
        <<<grid_dim, block_dim, 0, stream>>>(
            output_data, input_data, high_dim, mid_dim, low_dim);
    return;
  }
  NormalSoftmaxForward<T, AccT, LogSoftmaxForwardFunctor>
      <<<grid_dim, block_dim, 0, stream>>>(
          output_data, input_data, high_dim, mid_dim, low_dim);
}

1009
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
                                 T* input_grad_data,
                                 const T* output_grad_data,
                                 const T* output_data,
                                 int high_dim,
                                 int mid_dim,
                                 int low_dim) {
  // Host-side launcher for NormalSoftmaxBackward on dev_ctx's stream.
  // LogMode selects LogSoftmaxBackwardFunctor; otherwise SoftmaxBackwardFunctor
  // is used. LogMode is also forwarded to the kernel to pick the reduction.
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  dim3 grid_dim;
  dim3 block_dim;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid_dim, &block_dim);
  auto stream = dev_ctx.stream();
  if (!LogMode) {
    NormalSoftmaxBackward<T, AccT, SoftmaxBackwardFunctor, LogMode>
        <<<grid_dim, block_dim, 0, stream>>>(input_grad_data,
                                             output_grad_data,
                                             output_data,
                                             high_dim,
                                             mid_dim,
                                             low_dim);
    return;
  }
  NormalSoftmaxBackward<T, AccT, LogSoftmaxBackwardFunctor, LogMode>
      <<<grid_dim, block_dim, 0, stream>>>(input_grad_data,
                                           output_grad_data,
                                           output_data,
                                           high_dim,
                                           mid_dim,
                                           low_dim);
}

1039 1040 1041 1042 1043 1044
template <typename T = int>
static std::vector<T> GetSoftmaxTensorDims(const phi::DDim& dims,
                                           const int axis) {
  // Collapses `dims` around `axis` into the 4-entry shape
  // {outer, axis_extent, inner, 1} that the softmax paths use (the cuDNN path
  // describes it as NCHW). outer / inner come from SizeToAxis / SizeOutAxis,
  // presumably the products of the dimensions before / after `axis`.
  const auto axis_extent = static_cast<T>(dims[axis]);
  const auto outer = phi::funcs::SizeToAxis<T>(axis, dims);
  const auto inner = phi::funcs::SizeOutAxis<T>(axis, dims);
  std::vector<T> shape{outer, axis_extent, inner, 1};
  return shape;
}

template <typename T>
void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
                               const T* x_data,
                               const int axis,
                               const int rank,
                               const bool log_mode,
                               const std::vector<int>& tensor_dims,
                               T* out_data) {
  // Issues one cuDNN (MIOpen under PADDLE_WITH_HIP) softmax-forward call.
  // Input and output share a single NCHW descriptor built from tensor_dims.
  // INSTANCE mode is used when `axis` is the last dimension, CHANNEL mode
  // otherwise. log_mode selects the LOG algorithm instead of ACCURATE.
  // alpha = 1 and beta = 0, so out_data is overwritten, not blended.
  auto handle = dev_ctx.cudnn_handle();
  ScopedTensorDescriptor scoped_desc;
  const bool softmax_on_last_axis = (axis == rank - 1);
  auto alpha = paddle::platform::CudnnDataType<T>::kOne();
  auto beta = paddle::platform::CudnnDataType<T>::kZero();
#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t desc =
      scoped_desc.descriptor<T>(GPUDNNDataLayout::kNCHW, tensor_dims);
  auto algo = log_mode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
  auto mode = softmax_on_last_axis ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                   : MIOPEN_SOFTMAX_MODE_CHANNEL;
  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::miopenSoftmaxForward_V2(
      handle, alpha, desc, x_data, beta, desc, out_data, algo, mode));
#else
  cudnnTensorDescriptor_t desc =
      scoped_desc.descriptor<T>(GPUDNNDataLayout::kNCHW, tensor_dims);
  auto algo = log_mode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
  auto mode = softmax_on_last_axis ? CUDNN_SOFTMAX_MODE_INSTANCE
                                   : CUDNN_SOFTMAX_MODE_CHANNEL;
  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSoftmaxForward(
      handle, algo, mode, alpha, desc, x_data, beta, desc, out_data));
#endif
}

1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118
/**
 * Runs SoftmaxForwardCudnnKernel over `x`, splitting the outer dimension into
 * chunks so that each cuDNN call covers at most INT32_MAX elements. The only
 * exception is a single outer row that is itself larger than INT32_MAX; such
 * a row is still passed in one call.
 *
 * `x` is viewed as [N, dim, D, 1] (GetSoftmaxTensorDims). One outer row is
 * dim * D contiguous elements, and a chunk of `batch_size` rows starts at
 * offset row * dim * D.
 *
 * Fixes relative to the previous version:
 *   - The inner extent D is included in the chunk size and stride, so the
 *     CHANNEL-mode case (D > 1) advances by whole rows.
 *   - An empty tensor (N, dim or D == 0) returns without calling cuDNN,
 *     instead of dividing by zero.
 *   - Pointers are never advanced past the last chunk.
 */
template <typename T>
void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
                                     const DenseTensor& x,
                                     const int axis,
                                     const bool log_mode,
                                     DenseTensor* out) {
  auto* out_data = out->data<T>();
  auto* x_data = x.data<T>();
  const int rank = x.dims().size();

  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
  const int64_t num_rows = tensor_dims[0];
  const int64_t row_numel =
      static_cast<int64_t>(tensor_dims[1]) * tensor_dims[2];
  if (num_rows <= 0 || row_numel <= 0) {
    return;  // Nothing to compute.
  }
  // At least one row per call, even if a single row exceeds INT32_MAX.
  const int64_t batch_size = std::max<int64_t>(
      1, std::numeric_limits<int32_t>::max() / row_numel);
  for (int64_t row = 0; row < num_rows; row += batch_size) {
    tensor_dims[0] =
        static_cast<int>(std::min<int64_t>(num_rows - row, batch_size));
    const int64_t offset = row * row_numel;
    SoftmaxForwardCudnnKernel<T>(dev_ctx,
                                 x_data + offset,
                                 axis,
                                 rank,
                                 log_mode,
                                 tensor_dims,
                                 out_data + offset);
  }
}

1119 1120
template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                const T* out_data,
                                const T* dout_data,
                                const int axis,
                                const int rank,
                                const bool log_mode,
                                const std::vector<int>& tensor_dims,
                                T* dx_data) {
  // Issues one cuDNN (MIOpen under PADDLE_WITH_HIP) softmax-backward call.
  // out, dout and dx share a single NCHW descriptor built from tensor_dims.
  // INSTANCE mode is used when `axis` is the last dimension, CHANNEL mode
  // otherwise. log_mode selects the LOG algorithm instead of ACCURATE.
  // alpha = 1 and beta = 0, so dx_data is overwritten, not blended.
  auto handle = dev_ctx.cudnn_handle();
  ScopedTensorDescriptor scoped_desc;
  const bool softmax_on_last_axis = (axis == rank - 1);
  auto alpha = paddle::platform::CudnnDataType<T>::kOne();
  auto beta = paddle::platform::CudnnDataType<T>::kZero();
#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t desc =
      scoped_desc.descriptor<T>(GPUDNNDataLayout::kNCHW, tensor_dims);
  auto algo = log_mode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
  auto mode = softmax_on_last_axis ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                   : MIOPEN_SOFTMAX_MODE_CHANNEL;
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenSoftmaxBackward_V2(handle,
                                                          alpha,
                                                          desc,
                                                          out_data,
                                                          desc,
                                                          dout_data,
                                                          beta,
                                                          desc,
                                                          dx_data,
                                                          algo,
                                                          mode));
#else
  cudnnTensorDescriptor_t desc =
      scoped_desc.descriptor<T>(GPUDNNDataLayout::kNCHW, tensor_dims);
  auto algo = log_mode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
  auto mode = softmax_on_last_axis ? CUDNN_SOFTMAX_MODE_INSTANCE
                                   : CUDNN_SOFTMAX_MODE_CHANNEL;
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnSoftmaxBackward(handle,
                                                      algo,
                                                      mode,
                                                      alpha,
                                                      desc,
                                                      out_data,
                                                      desc,
                                                      dout_data,
                                                      beta,
                                                      desc,
                                                      dx_data));
#endif
}

1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204
/**
 * Runs SoftmaxBackwardCudnnKernel over (out, dout) -> dx, splitting the outer
 * dimension into chunks so that each cuDNN call covers at most INT32_MAX
 * elements. The only exception is a single outer row that is itself larger
 * than INT32_MAX; such a row is still passed in one call.
 *
 * The tensors are viewed as [N, dim, D, 1] (GetSoftmaxTensorDims). One outer
 * row is dim * D contiguous elements, and a chunk of `batch_size` rows starts
 * at offset row * dim * D in all three buffers.
 *
 * Fixes relative to the previous version:
 *   - D is included in the chunk size and stride.
 *   - An empty tensor (N, dim or D == 0) returns without calling cuDNN,
 *     instead of dividing by zero.
 *   - Pointers are never advanced past the last chunk.
 */
template <typename T>
void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                      const DenseTensor& out,
                                      const DenseTensor& dout,
                                      const int axis,
                                      const bool log_mode,
                                      DenseTensor* dx) {
  auto* dx_data = dx->data<T>();
  auto* out_data = out.data<T>();
  auto* dout_data = dout.data<T>();
  int rank = out.dims().size();

  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
  const int64_t num_rows = tensor_dims[0];
  const int64_t row_numel =
      static_cast<int64_t>(tensor_dims[1]) * tensor_dims[2];
  if (num_rows <= 0 || row_numel <= 0) {
    return;  // Nothing to compute.
  }
  // At least one row per call, even if a single row exceeds INT32_MAX.
  const int64_t batch_size = std::max<int64_t>(
      1, std::numeric_limits<int32_t>::max() / row_numel);
  for (int64_t row = 0; row < num_rows; row += batch_size) {
    tensor_dims[0] =
        static_cast<int>(std::min<int64_t>(num_rows - row, batch_size));
    const int64_t offset = row * row_numel;
    SoftmaxBackwardCudnnKernel<T>(dev_ctx,
                                  out_data + offset,
                                  dout_data + offset,
                                  axis,
                                  rank,
                                  log_mode,
                                  tensor_dims,
                                  dx_data + offset);
  }
}

1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228
/**
 * Launches KeMatrixSoftmaxForward on dev_ctx's stream with N blocks of
 * kBlockDim threads (presumably one block per row of length dim_size).
 *
 * The outer switch dispatches on getBlockSize(vec_size, dim_size). A size
 * that matches none of the FIXED_BLOCK_DIM cases hits the outer default and
 * throws errors::Fatal. The inner switch (inside each block-size case)
 * dispatches on vec_size = MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T), the number
 * of T elements per 16-byte vector.
 *
 * NOTE(review): if vec_size matches none of the FIXED_VEC_SIZE cases, the
 * inner `default: break;` launches nothing and returns silently, leaving
 * `out` unwritten. Confirm FIXED_VEC_SIZE covers every sizeof(T) that can
 * reach this path (e.g. an 8-byte T gives vec_size == 2).
 */
template <typename T, typename IndexType, bool LogMode>
void LaunchKeMatrixSoftmaxForwardKernel(
    const GPUContext& dev_ctx, T* out, const T* input, int N, int dim_size) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  // Elements of T per MATRIX_SOFTMAX_ALIGN_BYTES-byte vector access.
  const int vec_size = MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
  // Outer switch: block size; inner switch: vector width.
  switch (getBlockSize(vec_size, dim_size)) {
    FIXED_BLOCK_DIM(switch (vec_size) {
      FIXED_VEC_SIZE(
          KeMatrixSoftmaxForward<T,
                                 AccT,
                                 IndexType,
                                 kBlockDim,
                                 VecSize,
                                 LogMode>
          <<<N, kBlockDim, 0, dev_ctx.stream()>>>(out, input, dim_size));
      default:
        break;
    });
    default:
      PADDLE_THROW(
          errors::Fatal("the input dim has error in the softmax cuda kernel."));
  }
}

1229 1230
#if CUDNN_VERSION < 8100
template <>
inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
    const GPUContext& dev_ctx,
    const DenseTensor& x,
    const int axis,
    const bool log_mode,
    DenseTensor* out) {
  // Compiled only when CUDNN_VERSION < 8100. With such cuDNN versions the
  // bfloat16 softmax forward through cuDNN is rejected unconditionally.
  PADDLE_THROW(errors::Unavailable("This kernel is not supported when the "
                                   "dtype is bf16 and CUDNN_VERSION < 8100."));
}
template <>
inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
    const GPUContext& dev_ctx,
    const DenseTensor& out,
    const DenseTensor& dout,
    const int axis,
    const bool log_mode,
    DenseTensor* dx) {
  // Compiled only when CUDNN_VERSION < 8100. With such cuDNN versions the
  // bfloat16 softmax backward through cuDNN is rejected unconditionally.
  PADDLE_THROW(errors::Unavailable("This kernel is not supported when the "
                                   "dtype is bf16 and CUDNN_VERSION < 8100."));
}
#endif

1255
template <typename T>
bool UseCudnnSoftmax(const GPUContext& ctx,
                     int64_t softmax_dim,
                     bool last_dim) {
  // Decides whether a softmax over the innermost extent should go through
  // cuDNN / MIOpen instead of the kernels in this file. Returns true only when
  // all of the following hold:
  //   - a cuDNN handle exists,
  //   - the softmax runs over the last dimension,
  //   - softmax_dim is not "small" (<= 512 with sizeof(T) <= 4),
  //   - softmax_dim is below MATRIX_SOFTMAX_THREAHOLD.
  const bool has_handle = static_cast<bool>(ctx.cudnn_handle());
  bool cudnn_available = has_handle;
  // NOTE(review): this bf16 / cuDNN < 8.1 downgrade only runs when there is
  // no handle, where cudnn_available is already false, so it never changes
  // the result. As written, bf16 with cuDNN < 8.1 can still return true and
  // reach the throwing LaunchSoftmax*CudnnKernel<bfloat16> specializations.
  // Applying the check unconditionally looks intended. Before changing it,
  // confirm the non-cuDNN D == 1 path handles every softmax_dim it would then
  // receive: the warp kernels may not cover dim > 512.
  if (!has_handle && std::is_same<T, phi::dtype::bfloat16>::value) {
#if CUDNN_VERSION < 8100
    cudnn_available = false;
#endif
  }
  if (!cudnn_available || !last_dim) {
    return false;
  }
  constexpr int max_dim = 512;
  const bool small_dim = softmax_dim <= max_dim && sizeof(T) <= 4;
  const bool huge_dim = softmax_dim >= MATRIX_SOFTMAX_THREAHOLD;
  return !small_dim && !huge_dim;
}

1277 1278 1279 1280 1281
/**
 * Forward (log)softmax dispatcher for a given index type.
 *
 * The input is viewed as [N, dim, D] around the canonicalized axis:
 *   - D == 1 (softmax over the innermost extent):
 *       * cuDNN / MIOpen when UseCudnnSoftmax<T> returns true;
 *       * otherwise KeMatrixSoftmaxForward for dim >= MATRIX_SOFTMAX_THREAHOLD;
 *       * otherwise the warp-level kernels via SwitchWarpSoftmaxForward
 *         (each warp handles batches_per_warp rows), with 4- or 2-wide vector
 *         accesses when dim is divisible by 4 or 2.
 *   - any other D: the generic NormalSoftmaxForward kernel.
 *
 * D is kept in IndexType. When IndexType is int64_t, narrowing it to int
 * could truncate a large inner extent (e.g. 2^32 + 1 -> 1) and wrongly select
 * the D == 1 path.
 *
 * NOTE(review): for MATRIX_SOFTMAX_THREAHOLD > dim > 512 without a cuDNN
 * handle, the warp path receives dim_log2 > 9. Confirm that
 * SwitchWarpSoftmaxForward has cases for it and does not silently launch
 * nothing.
 */
template <typename T, typename IndexType, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
                                        const DenseTensor& x,
                                        const int input_axis,
                                        DenseTensor* out) {
  auto* out_data = out->data<T>();
  const T* x_data = x.data<T>();

  int rank = x.dims().size();
  int axis = phi::funcs::CanonicalAxis(input_axis, rank);
  std::vector<IndexType> tensor_dims =
      GetSoftmaxTensorDims<IndexType>(x.dims(), axis);
  IndexType N = tensor_dims[0];
  IndexType dim = tensor_dims[1];
  IndexType D = tensor_dims[2];

  if (D == 1) {
    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
      if (dim >= MATRIX_SOFTMAX_THREAHOLD) {
        LaunchKeMatrixSoftmaxForwardKernel<T, IndexType, LogMode>(
            dev_ctx, out_data, x_data, N, dim);
        return;
      }
      int dim_log2 = static_cast<int>(Log2Ceil(dim));
      IndexType dim_ceil = 1 << dim_log2;
      int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
      int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;

      // use 128 threads per block to maximize gpu utilization
      constexpr int threads_per_block = 128;

      int warps_per_block = (threads_per_block / warp_size);
      int batches_per_block = warps_per_block * batches_per_warp;
      IndexType blocks = (N + batches_per_block - 1) / batches_per_block;
      dim3 threads(warp_size, warps_per_block, 1);

      // vectorization read/write
      using T4 = typename VecT4<T>::Type;
      using T2 = typename VecT2<T>::Type;

      if (dim % 4 == 0) {
        SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
                                                            threads,
                                                            dev_ctx,
                                                            out_data,
                                                            x_data,
                                                            N,
                                                            dim,
                                                            dim,
                                                            dim_log2);
      } else if (dim % 2 == 0) {
        SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
                                                            threads,
                                                            dev_ctx,
                                                            out_data,
                                                            x_data,
                                                            N,
                                                            dim,
                                                            dim,
                                                            dim_log2);
      } else {
        SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
                                                           threads,
                                                           dev_ctx,
                                                           out_data,
                                                           x_data,
                                                           N,
                                                           dim,
                                                           dim,
                                                           dim_log2);
      }
    } else {
      LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
    }
  } else {
    // NOTE(review): LaunchNormalSoftmaxForward takes int extents, so N, dim
    // and D are narrowed here. Each must fit in int.
    LaunchNormalSoftmaxForward<T, LogMode>(dev_ctx, out_data, x_data, N, dim, D);
  }
}

1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                    const DenseTensor& x,
                                    const int input_axis,
                                    DenseTensor* out) {
  // Entry point for the forward (log)softmax. Tensors with at least
  // INT32_MAX elements are indexed with int64_t; all others with int32_t.
  const bool use_int64_index =
      x.numel() >= std::numeric_limits<int32_t>::max();
  if (!use_int64_index) {
    SoftmaxForwardCUDAKernelDriverImpl<T, int32_t, LogMode>(
        dev_ctx, x, input_axis, out);
    return;
  }
  SoftmaxForwardCUDAKernelDriverImpl<T, int64_t, LogMode>(
      dev_ctx, x, input_axis, out);
}

1370
/**
 * Backward (log)softmax dispatcher: computes dx from the forward output `out`
 * and its gradient `dout`.
 *
 * The tensors are viewed as [N, dim, D] around the canonicalized axis:
 *   - D != 1: the generic NormalSoftmaxBackward kernel.
 *   - D == 1 and UseCudnnSoftmax<T> returns true: cuDNN / MIOpen.
 *   - D == 1 and dim <= 512 (Log2Ceil(dim) <= 9): the warp-level kernels via
 *     SwitchWarpSoftmaxBackward, with 4- or 2-wide vector accesses when dim
 *     is divisible by 4 or 2.
 *   - D == 1 otherwise: NormalSoftmaxBackward with a unit inner extent.
 *
 * The last fallback fixes a silent failure. SwitchWarpSoftmaxBackward only
 * has cases for Log2Elements 0..9 and does nothing for larger values. Rows
 * with dim >= MATRIX_SOFTMAX_THREAHOLD (never sent to cuDNN) or with
 * dim > 512 and no cuDNN handle therefore used to leave dx unwritten.
 *
 * All extents are int, so tensors with more than INT_MAX elements are not
 * supported here.
 */
template <typename T, bool LogMode = false>
void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                     const DenseTensor& out,
                                     const DenseTensor& dout,
                                     const int input_axis,
                                     DenseTensor* dx) {
  auto* dx_data = dx->data<T>();

  int rank = out.dims().size();
  int axis = phi::funcs::CanonicalAxis(input_axis, rank);
  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
  int N = tensor_dims[0];
  int dim = tensor_dims[1];
  int D = tensor_dims[2];

  if (D != 1) {
    LaunchNormalSoftmaxBackward<T, LogMode>(
        dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
    return;
  }
  if (UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
    LaunchSoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
    return;
  }

  // Largest Log2Elements handled by SwitchWarpSoftmaxBackward's switch.
  constexpr int kMaxWarpLog2Elements = 9;
  int dim_log2 = Log2Ceil(dim);
  if (dim_log2 > kMaxWarpLog2Elements) {
    // Rows too long for the warp kernels: reduce over `dim` with the generic
    // kernel, treating the tensor as [N, dim, 1].
    LaunchNormalSoftmaxBackward<T, LogMode>(
        dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, 1);
    return;
  }

  int dim_ceil = 1 << dim_log2;
  int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
  int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;

  constexpr int threads_per_block = 128;

  int warps_per_block = (threads_per_block / warp_size);
  int batches_per_block = warps_per_block * batches_per_warp;
  int blocks = (N + batches_per_block - 1) / batches_per_block;
  dim3 threads(warp_size, warps_per_block, 1);

  const T* dout_data = dout.data<T>();
  const T* out_data = out.data<T>();

  // vectorization read/write
  using T4 = typename VecT4<T>::Type;
  using T2 = typename VecT2<T>::Type;
  if (dim % 4 == 0) {
    SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
                                              threads,
                                              dev_ctx,
                                              dx_data,
                                              dout_data,
                                              out_data,
                                              N,
                                              dim,
                                              dim,
                                              dim_log2);
  } else if (dim % 2 == 0) {
    SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
                                              threads,
                                              dev_ctx,
                                              dx_data,
                                              dout_data,
                                              out_data,
                                              N,
                                              dim,
                                              dim,
                                              dim_log2);
  } else {
    SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
                                             threads,
                                             dev_ctx,
                                             dx_data,
                                             dout_data,
                                             out_data,
                                             N,
                                             dim,
                                             dim,
                                             dim_log2);
  }
}
C
carryyu 已提交
1445 1446 1447 1448
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
#undef FIXED_VEC_SIZE_BASE
#undef FIXED_VEC_SIZE
1449

1450
}  // namespace phi