// softmax_gpudnn.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstdint>

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

namespace phi {

// Short aliases for the fluid (paddle::platform) cuDNN/MIOpen helpers used by
// the cuDNN-backed softmax kernels in this file.
using ScopedTensorDescriptor = paddle::platform::ScopedTensorDescriptor;
using GPUDNNDataLayout = paddle::platform::DataLayout;

// Vectorization trait 4 * sizeof(T): VecT4<T>::Type is a builtin vector type
// meant to be 4 * sizeof(T) bytes wide, so one load/store moves four T values.
// The primary template is empty; only the specializations below are usable.
// NOTE(review): long4 is built from `long`, which is 8 bytes on LP64 (Linux)
// but 4 bytes on LLP64 (Windows). There VecT4<double>::Type is only
// 2 * sizeof(double) wide. The warp kernels in this file derive
// kVSize = sizeof(VecT) / sizeof(T), so this presumably just narrows the
// vector width — confirm before relying on exactly four lanes.
template <typename T>
class VecT4 {};
template <>
class VecT4<double> {
 public:
  using Type = long4;
};
template <>
class VecT4<float> {
 public:
  using Type = int4;
};
template <>
class VecT4<phi::dtype::float16> {
 public:
  using Type = int2;
};
template <>
class VecT4<phi::dtype::bfloat16> {
 public:
  using Type = int2;
};

// Vectorization trait 2 * sizeof(T): VecT2<T>::Type is a builtin vector type
// 2 * sizeof(T) bytes wide, so one load/store moves two T values.
// The primary template is empty; only the specializations below are usable.
template <typename T>
class VecT2 {};
template <>
class VecT2<double> {
 public:
  using Type = int4;
};
template <>
class VecT2<float> {
 public:
  using Type = int2;
};
template <>
class VecT2<phi::dtype::float16> {
 public:
  using Type = int;
};
template <>
class VecT2<phi::dtype::bfloat16> {
 public:
  using Type = int;
};
// Returns the smallest k >= 0 with (1 << k) >= value, i.e. ceil(log2(value))
// for value >= 1. Returns 0 for value <= 1.
// The exponent is capped at 31. Without the cap, any value above 2^30 made
// the loop evaluate 1 << 31 and beyond (signed overflow / out-of-range shift),
// which is undefined and could keep the loop running on garbage comparisons.
static inline int Log2Ceil(int value) {
  int log2_value = 0;
  while (log2_value < 31 && (1 << log2_value) < value) ++log2_value;
  return log2_value;
}

// Butterfly (XOR-shuffle) sum over groups of WarpSize lanes, applied
// independently to each of the BatchSize entries of `sum`.
// Offsets run WarpSize/2, WarpSize/4, ..., 1. Each offset is below WarpSize,
// which is a power of two for the callers in this file. So a lane only
// exchanges with partners inside its own aligned WarpSize-lane group, and on
// return every lane of that group holds the group total in sum[i].
// The shuffle uses the full mask 0xFFFFFFFF, so all 32 lanes of the hardware
// warp are expected to reach this call.
// NOTE(review): CudaShuffleXorSync presumably wraps __shfl_xor_sync — confirm.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T sum_val =
          paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = sum[i] + sum_val;
    }
  }
}

// Butterfly (XOR-shuffle) maximum over groups of WarpSize lanes, applied
// independently to each of the BatchSize entries of the array (named `sum`
// for symmetry with WarpReduceSum). Each offset is below WarpSize, a power of
// two for the callers in this file, so exchanges stay inside an aligned
// WarpSize-lane group. On return every lane of the group holds the group
// maximum.
// The full mask 0xFFFFFFFF is used, so all 32 lanes of the hardware warp are
// expected to reach this call.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceMax(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T max_val =
          paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = max(sum[i], max_val);
    }
  }
}

// Binary max functor for kps::Reduce. initial() returns -infinity, the
// identity element for max.
template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor {
  inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }

  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
    return max(a, b);
  }
};

// Elementwise exponential: returns static_cast<Ty>(std::exp(x)).
template <typename Tx, typename Ty = Tx>
struct ExpFunctor {
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x));
  }
};

// Scaled exponential: returns static_cast<Ty>(std::exp(x) * y) for a factor y
// fixed at construction. The default factor is 1, giving plain exp(x).
template <typename Tx, typename Ty = Tx>
struct ExpMulFunctor {
  HOSTDEVICE inline ExpMulFunctor() : y(static_cast<Tx>(1.0f)) {}

  HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : y(static_cast<Tx>(y)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto scaled = std::exp(x) * y;
    return static_cast<Ty>(scaled);
  }

 private:
  Tx y;  // multiplicative factor applied after exp
};

// Subtracts a constant fixed at construction: returns static_cast<Ty>(x - y).
// The default constant is 0, giving a plain Tx -> Ty conversion.
template <typename Tx, typename Ty = Tx>
struct UnarySubFunctor {
  HOSTDEVICE inline UnarySubFunctor() : y(static_cast<Tx>(0.0f)) {}

  HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : y(static_cast<Tx>(y)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto difference = x - y;
    return static_cast<Ty>(difference);
  }

 private:
  Tx y;  // value subtracted from every input
};

// Elementwise natural logarithm: returns static_cast<Ty>(std::log(x)).
// The int constructor argument is accepted and ignored (presumably so the
// functor can be built like the other unary functors — confirm).
template <typename Tx, typename Ty = Tx>
struct UnaryLogFunctor {
  HOSTDEVICE inline UnaryLogFunctor() {}

  HOSTDEVICE explicit inline UnaryLogFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto log_value = std::log(x);
    return static_cast<Ty>(log_value);
  }
};

// Converts Tx to Ty. Tx's -infinity is mapped explicitly to Ty's -infinity
// instead of relying on static_cast for that value. Used by the warp kernels
// below to widen T to the accumulation type AccT.
// The int constructor argument is accepted and ignored.
template <typename Tx, typename Ty>
struct DataTransFunctor {
  HOSTDEVICE inline DataTransFunctor() {}

  HOSTDEVICE explicit inline DataTransFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    if (x == -std::numeric_limits<Tx>::infinity()) {
      return -std::numeric_limits<Ty>::infinity();
    }
    return static_cast<Ty>(x);
  }
};

// Divides every input by n using a reciprocal precomputed once at
// construction: returns static_cast<Ty>(x * (1 / n)). The default
// constructor uses a reciprocal of 1 (identity conversion).
// The reciprocal is formed in Tx rather than via the double literal `1.0`.
// For Tx = float this avoids an FP64 division in device code. With
// IEEE-compliant float division (nvcc's default -prec-div=true) it yields the
// same correctly rounded value as the double-then-narrow path: 53-bit
// precision is enough for the double -> float rounding of a single quotient to
// be innocuous. For Tx = double nothing changes.
template <typename Tx, typename Ty = Tx>
struct UnaryDivFunctor {
  HOSTDEVICE inline UnaryDivFunctor() { n_inv = static_cast<Tx>(1.0f); }

  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n)
      : n_inv(static_cast<Tx>(1.0f) / n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x * n_inv);
  }

 private:
  Tx n_inv;  // 1 / n
};

// Per-element softmax for a row whose maximum `max` and sum
// sum_j exp(x_j - max) are already known: returns exp(x - max) / sum.
// Used as the Functor of NormalSoftmaxForward.
template <typename Tx, typename Ty = Tx>
struct SoftmaxForwardFunctor {
  HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
      : max(max), sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x - max) / sum);
  }

 private:
  Tx max;
  Tx sum;
};

// Per-element softmax gradient given s = sum_j grad_out_j * out_j for the
// row: returns out * (grad_out - s). Used by NormalSoftmaxBackward.
template <typename Tx, typename Ty = Tx>
struct SoftmaxBackwardFunctor {
  HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    const auto centered_grad = grad_out - sum;
    return static_cast<Ty>(out * centered_grad);
  }

 private:
  Tx sum;  // row-wise dot product of grad_out and out
};

// Per-element log-softmax for a row with known maximum `max` and sum
// sum_j exp(x_j - max): returns (x - max) - log(sum). log(sum) is computed
// once in the constructor.
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxForwardFunctor {
  HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
      : max(max), log_sum(std::log(sum)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    const auto shifted = x - max;
    return static_cast<Ty>(shifted - log_sum);
  }

 private:
  Tx max;
  Tx log_sum;
};

// Per-element log-softmax gradient given s = sum_j grad_out_j for the row:
// returns grad_out - exp(out) * s. Here `out` is the log-softmax output, so
// exp(out) recovers the softmax probability.
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxBackwardFunctor {
  HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    return static_cast<Ty>(grad_out - sum * std::exp(out));
  }

 private:
  Tx sum;  // row-wise sum of grad_out
};

/*
Core function of computing softmax forward for axis=-1.
For each row i (a "batch") of element_count values:
  - Compute maximum of batch: maxvalue_{i} = max_{j} src_{i,j}
  - Compute sum of exp batch: s_{i} = sum_{j} exp(src_{i,j} - maxvalue_{i})
  - Compute softmax:          exp(src_{i,j} - maxvalue_{i}) / s_{i}
    or, when LogMode is true: src_{i,j} - maxvalue_{i} - log(s_{i})
A group of kWarpSize = min(2^Log2Elements, 32) lanes handles kBatchSize rows
(2 rows when 2^Log2Elements <= 32, otherwise 1). Each thread first reduces the
values it holds in registers (kps kLocalMode). The partial max (sum) is then
combined across the lane group with XOR shuffles (WarpReduceMax /
WarpReduceSum).
Indexing: first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize,
so threadIdx.y selects the row group within a block.
NOTE(review): threadIdx.x is not read here; it is presumably consumed inside
kps::ReadData / WriteData, which implies blockDim.x == kWarpSize — confirm
against the launcher.
NOTE(review): element_count is divided by kVSize with truncation, so the
caller presumably picks VecT such that kVSize divides element_count — confirm.
*/
template <typename T,
          typename VecT,
          typename AccT,
          int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax,
                                   const T* src,
                                   const int batch_size,
                                   const int stride,
                                   const int element_count) {
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kLoops = kDimCeil / kWarpSize;
  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
  constexpr int kVItem = kLoopsV * kVSize;
  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
  using kMode = kps::details::ReduceMode;

  // Number of VecT items to read/write per row. It is 0 for rows at or past
  // batch_size, so the boundary-checked kps::ReadData / WriteData presumably
  // touch no memory for them.
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // src_data: the raw data from global memory
  // sub_data: store the data obtained by (src_data - max), used by log_softmax
  // exp_data: store the data obtained by (exp(sub_data)), used by softmax
  T src_data[kBatchSize][kLoopsV][kVSize];
  AccT sub_data[kBatchSize][kLoopsV][kVSize];
  AccT exp_data[kBatchSize][kLoopsV][kVSize];
  // Slots that are never loaded keep -inf. They do not change the max and
  // contribute exp(-inf) = 0 to the sum.
  kps::Init<AccT, kStep>(&sub_data[0][0][0], kLowInf);
  kps::Init<T, kStep>(&src_data[0][0][0], -std::numeric_limits<T>::infinity());

  // data dst
  T out_tmp[kBatchSize][kLoopsV][kVSize];

  // max value
  AccT max[kBatchSize];
  kps::Init<AccT, kBatchSize>(&max[0], kLowInf);

  // sum value
  AccT sum[kBatchSize] = {0};

// read data from global memory
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // 64-bit row offset: (first_batch + i) * stride can exceed INT_MAX for
    // tensors with more than 2^31 elements.
    const int64_t row_offset = static_cast<int64_t>(first_batch + i) * stride;
    const VecT* src_v = reinterpret_cast<const VecT*>(&src[row_offset]);
    VecT* reg_v = reinterpret_cast<VecT*>(&src_data[i][0][0]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
        &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
    kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
        &sub_data[i][0][0], &src_data[i][0][0], DataTransFunctor<T, AccT>());
  }

  // compute max: per-thread reduction, then across the lane group
  kps::Reduce<AccT,
              kVItem,
              kBatchSize,
              1,
              ReduceMaxFunctor<AccT>,
              kMode::kLocalMode>(
      &max[0], &sub_data[0][0][0], ReduceMaxFunctor<AccT>(), true);
  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);

// compute sum: shift by the row max, exponentiate, then reduce
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
        &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor<AccT>(max[i]));
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpFunctor<AccT>>(
        &exp_data[i][0][0], &sub_data[i][0][0], ExpFunctor<AccT>());
  }
  kps::Reduce<AccT,
              kVItem,
              kBatchSize,
              1,
              kps::AddFunctor<AccT>,
              kMode::kLocalMode>(
      &sum[0], &exp_data[0][0][0], kps::AddFunctor<AccT>(), true);
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

// write data to global memory
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    const int64_t row_offset = static_cast<int64_t>(first_batch + i) * stride;
    VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[row_offset]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    if (LogMode) {
      // log_softmax = (x - max) - log(sum)
      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>(
          &out_tmp[i][0][0],
          &sub_data[i][0][0],
          UnarySubFunctor<AccT>(std::log(sum[i])));
    } else {
      // softmax = exp(x - max) / sum
      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
          &out_tmp[i][0][0], &exp_data[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
    }
    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
}

/*
Core function of computing softmax backward for axis=-1.
For each row i, with src expected to hold the forward output and grad the
gradient w.r.t. that output:
  - softmax:     s_{i} = sum_{j} src_{i,j} * grad_{i,j}
                 dst_{i,j} = src_{i,j} * (grad_{i,j} - s_{i})
  - log_softmax: s_{i} = sum_{j} grad_{i,j}
                 dst_{i,j} = grad_{i,j} - exp(src_{i,j}) * s_{i}
A group of kWarpSize = min(2^Log2Elements, 32) lanes handles kBatchSize rows
(2 rows when 2^Log2Elements <= 128, otherwise 1). Each thread first reduces
its register-resident values (kps kLocalMode). The partial sums are then
combined across the lane group with WarpReduceSum.
Row offsets are computed in int64_t because (first_batch + i) * stride can
exceed INT_MAX for large tensors.
*/
template <typename T,
          typename VecT,
          typename AccT,
          int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst,
                                    const T* grad,
                                    const T* src,
                                    int batch_size,
                                    int stride,
                                    int element_count) {
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kLoops = kDimCeil / kWarpSize;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  // Number of rows of this group that lie inside the batch (may be <= 0).
  int local_batches = min(batch_size - first_batch, kBatchSize);

  // max index to read (0 for rows at or past batch_size)
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // read data from global memory; slots that are never loaded stay zero and
  // therefore add nothing to the row sums
  VecT src_reg[kBatchSize][kLoopsV];
  VecT grad_reg[kBatchSize][kLoopsV];
  VecT k_value;
  for (int s = 0; s < kVSize; s++) {
    reinterpret_cast<T*>(&k_value)[s] = 0.0;
  }
  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    int flag = i < local_batches ? 1 : 0;
    const int64_t ptr = static_cast<int64_t>(first_batch + i) * stride;
    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
  }

  // change T to AccT
  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
  constexpr int kVItem = kLoopsV * kVSize;
  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());

  // compute sum
  AccT sum[kBatchSize]{0.0};
  AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
  AccT* gradptr = &grad_tmp[0][0][0];
  AccT* srcptr = &src_tmp[0][0][0];
  if (LogMode) {
    // log_softmax: s = sum(grad)
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                1,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &grad_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  } else {
    // softmax: s = sum(grad * src)
    kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
        &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                1,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // write result to global memory
  AccT out[kBatchSize][kLoopsV][kVSize];
  T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;
    AccT* grad_row = &grad_tmp[i][0][0];
    AccT* src_row = &src_tmp[i][0][0];
    if (LogMode) {
      // dst = grad - exp(src) * s
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpMulFunctor<AccT>>(
          &out[i][0][0], &src_row[0], ExpMulFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::SubFunctor<AccT>>(
          &out_tmp[i][0][0],
          &grad_row[0],
          &out[i][0][0],
          kps::SubFunctor<AccT>());
    } else {
      // dst = src * (grad - s)
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
          &out[i][0][0], &grad_row[0], UnarySubFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
          &out_tmp[i][0][0],
          &src_row[0],
          &out[i][0][0],
          kps::MulFunctor<AccT>());
    }
    VecT* dst_v = reinterpret_cast<VecT*>(
        &dst[static_cast<int64_t>(first_batch + i) * stride]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
}

// Expands to a `case Log2Elements:` label that launches the matching
// WarpSoftmaxForward instantiation. The expansion site must have T, VecT,
// LogMode, blocks, threads, dev_ctx, dst, src, batch_size, stride and
// element_count in scope.
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT)                      \
  case Log2Elements:                                                       \
    WarpSoftmaxForward<T,                                                  \
                       VecT,                                               \
                       AccT,                                               \
                       Log2Elements,                                       \
                       LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \
        dst, src, batch_size, stride, element_count);                      \
    break;

/*
  Wrapper of softmax forward with template instantiation on size of input.
  Maps the runtime Log2Elements (0..9) onto the matching
  WarpSoftmaxForward<T, VecT, AccT, Log2Elements, LogMode> instantiation, with
  AccT = MPTypeTrait<T>::Type. It launches that kernel with `blocks` x
  `threads` on dev_ctx's stream. Any other Log2Elements value matches no case
  and launches nothing.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxForward(const int blocks,
                              const dim3 threads,
                              const GPUContext& dev_ctx,
                              T* dst,
                              const T* src,
                              const int batch_size,
                              const int stride,
                              const int element_count,
                              int Log2Elements) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_FORWARD_CASE(0, AccT);
    SOFTMAX_WARP_FORWARD_CASE(1, AccT);
    SOFTMAX_WARP_FORWARD_CASE(2, AccT);
    SOFTMAX_WARP_FORWARD_CASE(3, AccT);
    SOFTMAX_WARP_FORWARD_CASE(4, AccT);
    SOFTMAX_WARP_FORWARD_CASE(5, AccT);
    SOFTMAX_WARP_FORWARD_CASE(6, AccT);
    SOFTMAX_WARP_FORWARD_CASE(7, AccT);
    SOFTMAX_WARP_FORWARD_CASE(8, AccT);
    SOFTMAX_WARP_FORWARD_CASE(9, AccT);
    default:
      // NOTE(review): presumably callers never pass a value above 9.
      break;
  }
}

// Expands to a `case Log2Elements:` label that launches the matching
// WarpSoftmaxBackward instantiation. The expansion site must have T, VecT,
// LogMode, blocks, threads, dev_ctx, dst, grad, src, batch_size, stride and
// element_count in scope.
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT)                      \
  case Log2Elements:                                                        \
    WarpSoftmaxBackward<T,                                                  \
                        VecT,                                               \
                        AccT,                                               \
                        Log2Elements,                                       \
                        LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \
        dst, grad, src, batch_size, stride, element_count);                 \
    break;

/*
Wrapper of softmax backward with template instantiation on size of input.
Maps the runtime Log2Elements (0..9) onto the matching
WarpSoftmaxBackward<T, VecT, AccT, Log2Elements, LogMode> instantiation, with
AccT = MPTypeTrait<T>::Type. It launches that kernel with `blocks` x `threads`
on dev_ctx's stream. Any other Log2Elements value matches no case and launches
nothing.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks,
                               const dim3 threads,
                               const GPUContext& dev_ctx,
                               T* dst,
                               const T* grad,
                               const T* src,
                               const int batch_size,
                               const int stride,
                               const int element_count,
                               int Log2Elements) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(2, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(3, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(4, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(5, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(6, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
    SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
    default:
      // NOTE(review): presumably callers never pass a value above 9.
      break;
  }
}

// The case macros are only needed by the two Switch* wrappers above. Undefine
// them so they do not leak into files that include this header.
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE

/**
 * <NormalSoftmaxKernel>
 * Better performance when axis != -1
 */

// Computes the grid for the Normal* softmax kernels given `block`.
// grid.x covers low_dim in block.x-wide tiles. It is capped by the number of
// blocks of this size that can be resident at once (multiprocessors * max
// threads per multiprocessor / threads per block). grid.y spends the
// remaining block budget on high_dim, capped at high_dim. Both kernels stride
// over low_id and high_id by the grid size, so a smaller grid still covers
// every element. mid_dim is accepted but not used.
// Every dimension is clamped to at least 1. Previously low_dim == 0 gave
// grid_x == 0 and a host division by zero. A zero grid dimension is also not
// a valid launch configuration.
static void GetGridDim(
    int high_dim, int mid_dim, int low_dim, const dim3& block, dim3* grid) {
  int device_id = phi::backends::gpu::GetCurrentDeviceId();
  int max_mp = phi::backends::gpu::GetGPUMultiProcessors(device_id);
  int max_threads_per_mp =
      phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
  int max_threads = max_threads_per_mp * max_mp;
  int num_threads = block.x * block.y;
  int max_num_blocks = std::max(max_threads / num_threads, 1);

  int grid_x = (low_dim + block.x - 1) / block.x;
  grid_x = std::max(std::min(grid_x, max_num_blocks), 1);
  int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
  grid_y = std::max(std::min(grid_y, high_dim), 1);
  grid->x = grid_x;
  grid->y = grid_y;
}

// Picks the block shape for the Normal* softmax kernels. block.x covers
// low_dim and block.y covers mid_dim, each rounded up to a power of two.
// block.x is first capped at 32. block.y then takes as much of the
// max_num_threads budget (256 on HIP, 1024 otherwise) as mid_dim needs.
// Finally block.x may grow back, up to its power-of-two round-up, into any
// budget block.y left unused.
static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
#ifdef __HIPCC__
  constexpr int max_num_threads = 256;
#else
  constexpr int max_num_threads = 1024;
#endif
  int block_x = 1 << Log2Ceil(low_dim);
  int block_y = 1 << Log2Ceil(mid_dim);
  block->x = std::min(block_x, 32);
  block->y = std::min(block_y, static_cast<int>(max_num_threads / block->x));
  block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y));
}

// Launch configuration for the Normal* softmax kernels: the block shape comes
// from (mid_dim, low_dim), then the grid is sized for that block.
static void GetLaunchConfig(
    int high_dim, int mid_dim, int low_dim, dim3* grid, dim3* block) {
  GetBlockDim(mid_dim, low_dim, block);
  GetGridDim(high_dim, mid_dim, low_dim, *block, grid);
}

/*
 * Softmax or log-softmax (chosen by Functor) over the middle axis of an input
 * addressed as [high_dim, mid_dim, low_dim], row-major with low_dim fastest.
 * blockIdx.y / gridDim.y stride over high_dim, and the x dimension of the
 * launch strides over low_dim. The threadIdx.y threads of a block split the
 * mid_dim loop. Their partial max and sum are combined by kps::Reduce in
 * kGlobalMode when blockDim.y > 1.
 * Functor<AccT, T> is built from (max, sum) and maps each input value to its
 * output.
 * Offsets are computed in int64_t so that high_dim * mid_dim * low_dim may
 * exceed INT_MAX.
 * NOTE(review): the low_id loop trip count can differ between threads of one
 * block (e.g. when blockDim.x does not divide low_dim). If kps::Reduce in
 * kGlobalMode synchronizes the whole block, threads that already left the
 * loop would not reach that barrier — confirm against the kps implementation.
 */
template <typename T,
          typename AccT,
          template <typename, typename> class Functor>
__global__ void NormalSoftmaxForward(
    T* output, const T* input, int high_dim, int mid_dim, int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int64_t high_stride = static_cast<int64_t>(mid_dim) * low_dim;
  const int64_t mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      const int64_t input_offset = high_id * high_stride + low_id;

      // 1. reduce max
      AccT max_value = -std::numeric_limits<AccT>::infinity();
      AccT value = -std::numeric_limits<AccT>::infinity();
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        max_value = kps::MaxFunctor<AccT>()(max_value, value);
      }

      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, 1, kps::MaxFunctor<AccT>, kMode::kGlobalMode>(
            &max_value, &max_value, kps::MaxFunctor<AccT>(), false);
      }

      // 2. reduce sum of exp(x - max)
      AccT sum = 0;
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        sum += std::exp(value - max_value);
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 3. (log)softmax
      Functor<AccT, T> functor(max_value, sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const int64_t data_offset = input_offset + mid_id * mid_stride;
        output[data_offset] = functor(static_cast<AccT>(input[data_offset]));
      }
    }
  }
}

/*
 * Softmax / log-softmax gradient over the middle axis of tensors addressed as
 * [high_dim, mid_dim, low_dim], row-major with low_dim fastest. The thread
 * layout is the same as NormalSoftmaxForward.
 * Step 1 reduces s = sum(output_grad) when LogMode is true, otherwise
 * s = sum(output_grad * output). Step 2 writes
 * Functor<AccT, T>(s)(output_grad, output) for every element.
 * Offsets are computed in int64_t so that high_dim * mid_dim * low_dim may
 * exceed INT_MAX.
 * NOTE(review): as in the forward kernel, the low_id loop trip count can
 * differ between threads of one block. If kps::Reduce in kGlobalMode
 * synchronizes the block, that is a divergent barrier — confirm.
 */
template <typename T,
          typename AccT,
          template <typename, typename> class Functor,
          bool LogMode>
__global__ void NormalSoftmaxBackward(T* input_grad,
                                      const T* output_grad,
                                      const T* output,
                                      int high_dim,
                                      int mid_dim,
                                      int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int64_t high_stride = static_cast<int64_t>(mid_dim) * low_dim;
  const int64_t mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      const int64_t grad_offset = high_id * high_stride + low_id;

      // 1. reduce sum
      AccT sum = 0;
      if (LogMode) {
        for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
          const int64_t data_offset = grad_offset + mid_id * mid_stride;
          sum += static_cast<AccT>(output_grad[data_offset]);
        }
      } else {
        for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
          const int64_t data_offset = grad_offset + mid_id * mid_stride;
          sum += static_cast<AccT>(output_grad[data_offset]) *
                 static_cast<AccT>(output[data_offset]);
        }
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 2. (log)softmax backward
      Functor<AccT, T> functor(sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        const int64_t data_offset = grad_offset + mid_id * mid_stride;
        input_grad[data_offset] =
            functor(static_cast<AccT>(output_grad[data_offset]),
                    static_cast<AccT>(output[data_offset]));
      }
    }
  }
}

/*
 * Host launcher for NormalSoftmaxForward on dev_ctx's stream. Data is
 * addressed as [high_dim, mid_dim, low_dim] with the reduction over mid_dim,
 * and accumulation uses AccT = MPTypeTrait<T>::Type. LogMode selects
 * LogSoftmaxForwardFunctor instead of SoftmaxForwardFunctor.
 * NOTE(review): the launch result is not checked here.
 */
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const GPUContext& dev_ctx,
                                T* output_data,
                                const T* input_data,
                                int high_dim,
                                int mid_dim,
                                int low_dim) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  dim3 grid, block;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
  if (LogMode) {
    NormalSoftmaxForward<
        T,
        AccT,
        LogSoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
        output_data, input_data, high_dim, mid_dim, low_dim);
  } else {
    NormalSoftmaxForward<
        T,
        AccT,
        SoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
        output_data, input_data, high_dim, mid_dim, low_dim);
  }
}

/*
 * Host launcher for NormalSoftmaxBackward on dev_ctx's stream. Data is
 * addressed as [high_dim, mid_dim, low_dim] with the reduction over mid_dim,
 * and accumulation uses AccT = MPTypeTrait<T>::Type. LogMode selects
 * LogSoftmaxBackwardFunctor instead of SoftmaxBackwardFunctor and is
 * forwarded to the kernel to choose the matching row sum.
 * NOTE(review): the launch result is not checked here.
 */
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
                                 T* input_grad_data,
                                 const T* output_grad_data,
                                 const T* output_data,
                                 int high_dim,
                                 int mid_dim,
                                 int low_dim) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  dim3 grid, block;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
  if (LogMode) {
    NormalSoftmaxBackward<T,
                          AccT,
                          LogSoftmaxBackwardFunctor,
                          LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
        input_grad_data,
        output_grad_data,
        output_data,
        high_dim,
        mid_dim,
        low_dim);
  } else {
    NormalSoftmaxBackward<T,
                          AccT,
                          SoftmaxBackwardFunctor,
                          LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
        input_grad_data,
        output_grad_data,
        output_data,
        high_dim,
        mid_dim,
        low_dim);
  }
}

// Describes `dims` as the 4-D shape {N, dim, D, 1} used for the cuDNN/MIOpen
// descriptors. dim = dims[axis], N = phi::funcs::SizeToAxis(axis, dims) and
// D = phi::funcs::SizeOutAxis(axis, dims). These are presumably the products
// of the extents before and after `axis` — see axis_utils.h.
// NOTE(review): `axis` indexes dims directly, so it is presumably already
// normalized to [0, rank) by the caller.
static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
                                             const int axis) {
  int dim = dims[axis];
  int N = phi::funcs::SizeToAxis(axis, dims);
  int D = phi::funcs::SizeOutAxis(axis, dims);
  return {N, dim, D, 1};
}

/*
 * Softmax / log-softmax forward through cuDNN, or MIOpen when built with
 * PADDLE_WITH_HIP. `x` is described as a 4-D NCHW tensor via
 * GetSoftmaxTensorDims. INSTANCE mode is used when `axis` is the last
 * dimension and CHANNEL mode otherwise. log_mode selects the LOG algorithm,
 * otherwise ACCURATE is used. alpha = kOne() and beta = kZero() are passed,
 * and the same descriptor serves input and output. Library failures are
 * surfaced through PADDLE_ENFORCE_GPU_SUCCESS.
 */
template <typename T>
void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
                               const DenseTensor& x,
                               const int axis,
                               const bool log_mode,
                               DenseTensor* out) {
  T* y_ptr = out->data<T>();

  const int rank = x.dims().size();
  const bool over_last_dim = (axis == rank - 1);
  std::vector<int> desc_dims = GetSoftmaxTensorDims(x.dims(), axis);

  auto dnn_handle = dev_ctx.cudnn_handle();
  GPUDNNDataLayout data_layout = GPUDNNDataLayout::kNCHW;
  ScopedTensorDescriptor desc_owner;
  using ScaleT = paddle::platform::CudnnDataType<T>;
  const T* x_ptr = x.data<T>();
#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t io_desc =
      desc_owner.descriptor<T>(data_layout, desc_dims);
  auto softmax_mode = over_last_dim ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                    : MIOPEN_SOFTMAX_MODE_CHANNEL;
  auto softmax_algo = log_mode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::miopenSoftmaxForward_V2(
      dnn_handle,
      ScaleT::kOne(),
      io_desc,
      x_ptr,
      ScaleT::kZero(),
      io_desc,
      y_ptr,
      softmax_algo,
      softmax_mode));
#else
  cudnnTensorDescriptor_t io_desc =
      desc_owner.descriptor<T>(data_layout, desc_dims);
  auto softmax_mode = over_last_dim ? CUDNN_SOFTMAX_MODE_INSTANCE
                                    : CUDNN_SOFTMAX_MODE_CHANNEL;
  auto softmax_algo = log_mode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSoftmaxForward(
      dnn_handle,
      softmax_algo,
      softmax_mode,
      ScaleT::kOne(),
      io_desc,
      x_ptr,
      ScaleT::kZero(),
      io_desc,
      y_ptr));
#endif
}

// Softmax (or log-softmax when `log_mode` is true) backward pass through
// cuDNN, or MIOpen when built with PADDLE_WITH_HIP.
//
// Given the forward result `out` and its gradient `dout`, writes the input
// gradient into `dx` (alpha = 1, beta = 0, so `dx` is overwritten). All three
// tensors share one 4-D NCHW descriptor built from
// GetSoftmaxTensorDims(out.dims(), axis). INSTANCE mode is selected when
// `axis` is the last dimension, CHANNEL mode otherwise. `axis` indexes
// out.dims() directly, so it must be non-negative (the drivers in this file
// canonicalize it first).
template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                const DenseTensor& out,
                                const DenseTensor& dout,
                                const int axis,
                                const bool log_mode,
                                DenseTensor* dx) {
  T* grad_in = dx->data<T>();

  const bool over_last_dim = (axis == out.dims().size() - 1);
  std::vector<int> nchw_dims = GetSoftmaxTensorDims(out.dims(), axis);

  auto dnn_handle = dev_ctx.cudnn_handle();
  GPUDNNDataLayout nchw_layout = GPUDNNDataLayout::kNCHW;
  ScopedTensorDescriptor desc_holder;

#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t shared_desc =
      desc_holder.descriptor<T>(nchw_layout, nchw_dims);
  const auto softmax_mode = over_last_dim ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                          : MIOPEN_SOFTMAX_MODE_CHANNEL;
  const auto softmax_algo =
      log_mode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
  // MIOpen's V2 entry point takes algorithm and mode as trailing arguments.
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenSoftmaxBackward_V2(
          dnn_handle,
          paddle::platform::CudnnDataType<T>::kOne(),
          shared_desc,
          out.data<T>(),
          shared_desc,
          dout.data<T>(),
          paddle::platform::CudnnDataType<T>::kZero(),
          shared_desc,
          grad_in,
          softmax_algo,
          softmax_mode));
#else
  cudnnTensorDescriptor_t shared_desc =
      desc_holder.descriptor<T>(nchw_layout, nchw_dims);
  const auto softmax_mode = over_last_dim ? CUDNN_SOFTMAX_MODE_INSTANCE
                                          : CUDNN_SOFTMAX_MODE_CHANNEL;
  const auto softmax_algo =
      log_mode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
  // cuDNN takes algorithm and mode right after the handle.
  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSoftmaxBackward(
      dnn_handle,
      softmax_algo,
      softmax_mode,
      paddle::platform::CudnnDataType<T>::kOne(),
      shared_desc,
      out.data<T>(),
      shared_desc,
      dout.data<T>(),
      paddle::platform::CudnnDataType<T>::kZero(),
      shared_desc,
      grad_in));
#endif
}

// Returns true when the cuDNN / MIOpen softmax path may be used for element
// type T on this context: a DNN handle must exist, and bfloat16 is rejected
// whenever CUDNN_VERSION < 8100 (which includes builds where CUDNN_VERSION is
// not defined, since the preprocessor then treats it as 0).
template <typename T>
static bool CanUseCudnnSoftmax(const GPUContext& dev_ctx) {
  if (dev_ctx.cudnn_handle() == nullptr) {
    return false;
  }
#if CUDNN_VERSION < 8100
  if (std::is_same<T, phi::dtype::bfloat16>::value) {
    return false;
  }
#endif
  return true;
}

904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929
#if CUDNN_VERSION < 8100
// bfloat16 stubs for library versions older than cuDNN 8.1 (CUDNN_VERSION
// 8100). They keep the bf16 instantiations of the drivers compiling, and
// raise an Unavailable error if ever reached at runtime. CanUseCudnnSoftmax
// returns false for bf16 under the same condition.
template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
    const GPUContext& /*dev_ctx*/,
    const DenseTensor& /*x*/,
    const int /*axis*/,
    const bool /*log_mode*/,
    DenseTensor* /*out*/) {
  PADDLE_THROW(errors::Unavailable(
      "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
      "8100."));
}

template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
    const GPUContext& /*dev_ctx*/,
    const DenseTensor& /*out*/,
    const DenseTensor& /*dout*/,
    const int /*axis*/,
    const bool /*log_mode*/,
    DenseTensor* /*dx*/) {
  PADDLE_THROW(errors::Unavailable(
      "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
      "8100."));
}
#endif

930
// Computes softmax (or log-softmax when LogMode is true) of `x` along
// `input_axis`, writing into `out`, which the caller must have allocated.
//
// `input_axis` may be negative; it is canonicalized against the rank of `x`.
// The input is viewed as {N, dim, D} (see GetSoftmaxTensorDims) and one of
// three implementations is chosen:
//   * D == 1 and (cuDNN unusable, or dim <= 512 with sizeof(T) <= 4):
//     SwitchWarpSoftmaxForward, using 4- or 2-wide vector accesses when dim
//     is divisible by 4 or 2;
//   * D > 1: LaunchNormalSoftmaxForward;
//   * otherwise: cuDNN / MIOpen via SoftmaxForwardCudnnKernel.
// An empty `x` is a no-op.
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                    const DenseTensor& x,
                                    const int input_axis,
                                    DenseTensor* out) {
  // Empty input: nothing to compute. Returning here also avoids launching a
  // grid with zero blocks (N == 0 makes `blocks` 0 below), which CUDA rejects
  // as an invalid configuration, and avoids a zero-extent cuDNN descriptor.
  if (x.numel() == 0) {
    return;
  }

  auto* out_data = out->data<T>();

  int rank = x.dims().size();
  int axis = phi::funcs::CanonicalAxis(input_axis, rank);
  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
  int N = tensor_dims[0];
  int dim = tensor_dims[1];
  int D = tensor_dims[2];

  // When cuDNN is usable, the warp kernel is still preferred for rows of at
  // most this many elements (and element types of at most 4 bytes).
  constexpr int max_dim = 512;

  if (D == 1 &&
      (!CanUseCudnnSoftmax<T>(dev_ctx) || (dim <= max_dim && sizeof(T) <= 4))) {
    // NOTE(review): when cuDNN is unusable this branch is taken for any dim;
    // confirm SwitchWarpSoftmaxForward (defined elsewhere) handles every
    // dim_log2 it can receive here.
    int dim_log2 = static_cast<int>(Log2Ceil(dim));
    int dim_ceil = 1 << dim_log2;
    // threads.x covers one row; rows shorter than 32 use fewer lanes.
    int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
    // NOTE(review): presumably must match the per-warp batch size compiled
    // into the forward warp kernel (defined elsewhere); verify before editing.
    int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // vectorization read/write
    using T4 = typename VecT4<T>::Type;
    using T2 = typename VecT2<T>::Type;

    if (dim % 4 == 0) {
      SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
                                               threads,
                                               dev_ctx,
                                               out_data,
                                               x.data<T>(),
                                               N,
                                               dim,
                                               dim,
                                               dim_log2);
    } else if (dim % 2 == 0) {
      SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
                                               threads,
                                               dev_ctx,
                                               out_data,
                                               x.data<T>(),
                                               N,
                                               dim,
                                               dim,
                                               dim_log2);
    } else {
      SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
                                              threads,
                                              dev_ctx,
                                              out_data,
                                              x.data<T>(),
                                              N,
                                              dim,
                                              dim,
                                              dim_log2);
    }
  } else if (D > 1) {
    LaunchNormalSoftmaxForward<T, LogMode>(
        dev_ctx, out_data, x.data<T>(), N, dim, D);
  } else {
    SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
  }
}

// Computes the gradient of softmax (or log-softmax when LogMode is true)
// along `input_axis`. Inputs are the forward result `out` and its gradient
// `dout`; the result is written into `dx`, which the caller must have
// allocated.
//
// `input_axis` may be negative; it is canonicalized against the rank of
// `out`. The tensors are viewed as {N, dim, D} (see GetSoftmaxTensorDims)
// and one of three implementations is chosen:
//   * D == 1 and (cuDNN unusable, or dim <= 512 with sizeof(T) <= 4):
//     SwitchWarpSoftmaxBackward, using 4- or 2-wide vector accesses when dim
//     is divisible by 4 or 2;
//   * D > 1: LaunchNormalSoftmaxBackward;
//   * otherwise: cuDNN / MIOpen via SoftmaxBackwardCudnnKernel.
// An empty `out` is a no-op.
template <typename T, bool LogMode = false>
void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                     const DenseTensor& out,
                                     const DenseTensor& dout,
                                     const int input_axis,
                                     DenseTensor* dx) {
  // Empty input: nothing to compute. Returning here also avoids launching a
  // grid with zero blocks (N == 0 makes `blocks` 0 below), which CUDA rejects
  // as an invalid configuration, and avoids a zero-extent cuDNN descriptor.
  if (out.numel() == 0) {
    return;
  }

  auto* dx_data = dx->data<T>();

  int rank = out.dims().size();
  int axis = phi::funcs::CanonicalAxis(input_axis, rank);
  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
  int N = tensor_dims[0];
  int dim = tensor_dims[1];
  int D = tensor_dims[2];

  // When cuDNN is usable, the warp kernel is still preferred for rows of at
  // most this many elements (and element types of at most 4 bytes).
  constexpr int max_dim = 512;

  if (D == 1 &&
      (!CanUseCudnnSoftmax<T>(dev_ctx) || (dim <= max_dim && sizeof(T) <= 4))) {
    // NOTE(review): when cuDNN is unusable this branch is taken for any dim;
    // confirm SwitchWarpSoftmaxBackward (defined elsewhere) handles every
    // dim_log2 it can receive here.
    int dim_log2 = static_cast<int>(Log2Ceil(dim));
    int dim_ceil = 1 << dim_log2;
    // threads.x covers one row; rows shorter than 32 use fewer lanes.
    int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
    // NOTE(review): intentionally differs from the forward driver (<= 32);
    // presumably must match the per-warp batch size compiled into the
    // backward warp kernel (defined elsewhere). Verify before editing.
    int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // vectorization read/write
    using T4 = typename VecT4<T>::Type;
    using T2 = typename VecT2<T>::Type;

    if (dim % 4 == 0) {
      SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
                                                threads,
                                                dev_ctx,
                                                dx_data,
                                                dout.data<T>(),
                                                out.data<T>(),
                                                N,
                                                dim,
                                                dim,
                                                dim_log2);
    } else if (dim % 2 == 0) {
      SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
                                                threads,
                                                dev_ctx,
                                                dx_data,
                                                dout.data<T>(),
                                                out.data<T>(),
                                                N,
                                                dim,
                                                dim,
                                                dim_log2);
    } else {
      SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
                                               threads,
                                               dev_ctx,
                                               dx_data,
                                               dout.data<T>(),
                                               out.data<T>(),
                                               N,
                                               dim,
                                               dim,
                                               dim_log2);
    }
  } else if (D > 1) {
    LaunchNormalSoftmaxBackward<T, LogMode>(
        dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
  } else {
    SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
  }
}

1080
}  // namespace phi