/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
11 12 13 14 15 16 17
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
18
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
S
sneaxiy 已提交
19
#include "paddle/fluid/operators/math/cross_entropy.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
21
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
22
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
S
sneaxiy 已提交
23
#include "paddle/fluid/platform/for_range.h"
24
#include "paddle/phi/kernels/funcs/math_function.h"
25
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
26

C
caoying03 已提交
27 28 29
namespace paddle {
namespace operators {

30 31
// Alignment boundary, in bytes, used to decide how far a row pointer lies
// past the previous aligned address before switching to vectorized
// loads/stores.
#define ALIGN_BYTES 16

// Short aliases for framework/platform types used throughout this file.
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;

36
// Device-side natural log. The value is widened to the accumulation type
// given by details::MPTypeTrait<T> (e.g. float for float16), the log is taken
// there, and the narrowed result is passed through math::TolerableValue<T>.
template <typename T>
static __device__ __forceinline__ T Log(T x) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  AccT logx = std::log(static_cast<AccT>(x));
  return math::TolerableValue<T>()(static_cast<T>(logx));
}

// Device-side exponential with the same widening / TolerableValue handling
// as Log.
template <typename T>
static __device__ __forceinline__ T Exp(T x) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  AccT expx = std::exp(static_cast<AccT>(x));
  return math::TolerableValue<T>()(static_cast<T>(expx));
}

52 53 54 55 56 57 58 59 60 61 62 63
// Binary reduction functor that accumulates exp(x - max) into a running sum:
//   f(sum, x) = sum + exp(x - max)
// `max` is fixed at construction; shifting by the row maximum keeps exp()
// from overflowing when computing the softmax denominator.
template <typename Tx, typename Ty = Tx>
struct ExpAddFunctor {
  HOSTDEVICE inline ExpAddFunctor(Tx max) : max_value_(max) {}

  HOSTDEVICE inline Ty operator()(const Tx& sum, const Tx& x) const {
    const Tx shifted = x - max_value_;
    return static_cast<Ty>(sum + std::exp(shifted));
  }

 private:
  Tx max_value_;
};

64 65 66 67 68 69 70 71 72 73 74 75
// Returns ceil(log2(value)): the smallest k >= 0 with (1 << k) >= value.
// Inputs <= 1 yield 0. The comparison is done in 64-bit so that inputs above
// 2^30 do not evaluate `1 << 31` on a 32-bit int (signed overflow, UB).
static inline int Log2Ceil(int value) {
  int log2_value = 0;
  while ((static_cast<int64_t>(1) << log2_value) < value) ++log2_value;
  return log2_value;
}

// Selects what WarpSoftmaxForward writes:
//   kSoftmax      - softmax probabilities
//   kLogSoftmax   - log-softmax values
//   kCrossEntropy - softmax probabilities plus the hard-label loss
enum class SoftmaxMode {
  kSoftmax,
  kLogSoftmax,
  kCrossEntropy,
};

/*
  Hard label cross entropy.
  Input:  softmax probabilities, flattened from a [n, dim, d] layout.
  Output: loss[i] = -log(softmax[idx_n, label[i], idx_d]) for i in [0, n * d),
          where idx_n = i / d and idx_d = i % d.
  A label outside [0, dim) yields a loss of 0 (softmax is never indexed with
  it). When IgnoreIndex is true, a label equal to ignore_idx also yields 0.
  Launch: 1-D grid, one thread per loss element; extra threads exit.
*/
template <typename T, typename LabelT, bool IgnoreIndex>
__global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
                                      const LabelT* labels, const int n,
                                      const int dim, const int d,
                                      const int ignore_idx) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic, and n * d in 32-bit int; both can wrap.
  const int64_t ids =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (ids >= static_cast<int64_t>(n) * d) return;

  const int64_t idx_n = ids / d;
  const int64_t idx_d = ids % d;
  const auto lbl = static_cast<int64_t>(labels[ids]);

  if (lbl < 0 || lbl >= dim) {
    // Label outside the class range: do not read softmax with it.
    loss[ids] = static_cast<T>(0.0);
  } else if (IgnoreIndex && lbl == ignore_idx) {
    loss[ids] = static_cast<T>(0.0);
  } else {
    const int64_t idx = idx_n * dim * d + lbl * d + idx_d;
    loss[ids] = -Log(softmax[idx]);
  }
}

/*
  Hard label cross entropy with exp.
  Input:  log-softmax values in `softmax`, flattened from [n, dim, d].
  Output: loss[idx_n * d + idx_d] = -logsoftmax[idx_n, label, idx_d], and
          every element of `softmax` is overwritten in place with exp(input).
  A label outside [0, dim) yields a loss of 0. When IgnoreIndex is true, a
  label equal to ignore_idx also yields 0.
  Launch: 1-D grid, one thread per element of `softmax` (n * dim * d);
  extra threads exit.
*/
template <typename T, typename LabelT, bool IgnoreIndex>
__global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
                                         const LabelT* labels, const int n,
                                         const int dim, const int d,
                                         const int ignore_idx) {
  // 64-bit index math: blockIdx.x * blockDim.x (32-bit unsigned) and
  // n * dim * d (32-bit int) can both wrap for large inputs.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t plane = static_cast<int64_t>(d) * dim;
  if (idx >= plane * n) return;

  const int64_t idx_n = idx / plane;
  const int64_t idx_dim = (idx / d) % dim;
  const int64_t idx_d = idx % d;
  const int64_t ids = idx_n * d + idx_d;

  const auto lbl = static_cast<int64_t>(labels[ids]);
  if (lbl >= 0 && lbl < dim) {
    // Exactly one thread per (idx_n, idx_d) sits on the labelled column.
    if (lbl == idx_dim) {
      if (IgnoreIndex && lbl == ignore_idx) {
        loss[ids] = static_cast<T>(0.0);
      } else {
        loss[ids] = -softmax[idx];
      }
    }
  } else if (idx_dim == 0) {
    // Invalid label: a single thread per (idx_n, idx_d) writes the zero loss.
    loss[ids] = static_cast<T>(0.0);
  }

  // Read for the loss above happens before this in-place overwrite.
  softmax[idx] = Exp(softmax[idx]);
}

148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
/*
  Core function of softmax with cross entropy forward
    - softmax, SoftmaxMode=kSoftmax
    - log softmax, SoftmaxMode=kLogSoftmax
    - softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy
  The computation includes
    - Compute max value: maxvalue_{i} = max_j src_{i,j}
    - Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
    - Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i}
    - Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i})
    - Compute: loss_{i} = -logsoftmax[i,label[i]] (Hard label)
  This computation follows from the identities:
    softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}}
                  = e^{src_{i,j} - maxvalue_{i}}
                    / sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
                  = e^{src_{i,j} - maxvalue_{i}} / s_{i}
    logsoftmax_{i,j} = log(softmax_{i,j})
                     = src_{i,j} - maxvalue_{i} - log(s_{i})

  Launch layout (as set up by SwitchWarpSoftmaxForward):
    - blockDim.x must equal kWarpSize = min(2^Log2Elements, 32); threadIdx.y
      selects the warp within the block.
    - Each warp handles kBatchSize rows (2 if 2^Log2Elements <= 128, else 1),
      starting at row (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize.
    - Row r is read from src[r * stride + j] for j < element_count; rows at
      or beyond batch_size are neither read nor written.
  Each lane first reduces its own elements; the per-lane max and sum are then
  combined across the warp with phi::WarpReduceMax / phi::WarpReduceSum.

  Loss (kCrossEntropy only): a label outside [0, element_count) yields a loss
  of 0; when IgnoreIndex is true, a label equal to ignore_index also yields 0.
*/
template <typename T, typename LabelT, typename VecT, typename AccT,
          int Log2Elements, SoftmaxMode mode, bool IgnoreIndex>
__global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
                                   const LabelT* label, const int batch_size,
                                   const int stride, const int element_count,
                                   const int ignore_index) {
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kIterations = kDimCeil / kWarpSize;
  constexpr int kIterationsV =
      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;

  // Number of (vectorized) elements to touch per row; 0 for rows at or beyond
  // batch_size, which disables every load and store for that row.
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // Read data from global memory. Missing elements are padded with -inf so
  // they never win the max and contribute exp(-inf) = 0 to the sum.
  AccT srcdata[kBatchSize][kIterationsV][kVSize];

#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
// read data to srcdata: - KVSize==1, - KVSize>1
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int src_idx = threadIdx.x + it * kWarpSize;
      if (kVSize == 1) {
        if (src_idx < idx_max_v[i]) {
          srcdata[i][it][0] =
              static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
        } else {
          srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
        }
      } else {
        const VecT* src_v =
            reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
        if (src_idx < idx_max_v[i]) {
          VecT srctmp = src_v[src_idx];
          const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
          for (int s = 0; s < kVSize; s++) {
            srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
          }
        } else {
#pragma unroll
          for (int s = 0; s < kVSize; s++) {
            srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
          }
        }
      }
    }
  }

  // compute max value: maxvalue_{i} = max_j src_{i,j}
  AccT max_value[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // it = 0
    AccT valmax = srcdata[i][0][0];
#pragma unroll
    for (int s = 1; s < kVSize; ++s) {
      valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
    }
    max_value[i] = valmax;

// it = 1, 2, ...
#pragma unroll
    for (int it = 1; it < kIterationsV; ++it) {
      AccT valmax = srcdata[i][it][0];
#pragma unroll
      for (int s = 1; s < kVSize; ++s) {
        valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
      }
      max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
    }
  }
  phi::WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);

  // compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} }
  // In kSoftmax mode the shifted exponentials are kept in srcdata for the
  // final division; the other modes recompute from the raw values.
  AccT sum[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // it = 0
    if (mode == SoftmaxMode::kLogSoftmax ||
        mode == SoftmaxMode::kCrossEntropy) {
      sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
    } else {
      srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
      sum[i] = srcdata[i][0][0];
    }
#pragma unroll
    for (int s = 1; s < kVSize; ++s) {
      if (mode == SoftmaxMode::kLogSoftmax ||
          mode == SoftmaxMode::kCrossEntropy) {
        sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
      } else {
        srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
        sum[i] += srcdata[i][0][s];
      }
    }

// it = 1, 2, ...
#pragma unroll
    for (int it = 1; it < kIterationsV; ++it) {
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (mode == SoftmaxMode::kLogSoftmax ||
            mode == SoftmaxMode::kCrossEntropy) {
          sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
        } else {
          srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
          sum[i] += srcdata[i][it][s];
        }
      }
    }
  }
  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

// write data
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (mode == SoftmaxMode::kLogSoftmax ||
        mode == SoftmaxMode::kCrossEntropy) {
      sum[i] = std::log(sum[i]);
    }

#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int idx = threadIdx.x + it * kWarpSize;
      // idx grows with `it`, so once it passes the end of the row every later
      // iteration would too. Checking here (before touching label/loss) also
      // keeps rows beyond batch_size, where idx_max_v[i] == 0, untouched.
      if (idx >= idx_max_v[i]) {
        break;
      }
      if (kVSize == 1) {  // kVSize==1
        if (mode == SoftmaxMode::kLogSoftmax) {  // log softmax
          softmax[(first_batch + i) * stride + idx] =
              srcdata[i][it][0] - max_value[i] - sum[i];
        } else if (mode == SoftmaxMode::kCrossEntropy) {
          // softmax with cross entropy hard label
          AccT logsoftmax = srcdata[i][it][0] - max_value[i] - sum[i];
          // softmax
          softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax);
          // loss
          int loss_idx = idx;
          auto lbl = static_cast<int64_t>(label[first_batch + i]);
          if (lbl >= 0 && lbl < element_count) {
            if (lbl == loss_idx) {
              if (IgnoreIndex && lbl == ignore_index) {
                loss[first_batch + i] = static_cast<T>(0.0);
              } else {
                loss[first_batch + i] = -logsoftmax;
              }
            }
          } else {
            loss[first_batch + i] = static_cast<T>(0.0);
          }
        } else {  // softmax
          softmax[(first_batch + i) * stride + idx] =
              srcdata[i][it][0] / sum[i];
        }
      } else {  // kVSize > 1
        VecT* softmax_v =
            reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
        VecT tmpdata;
        T* tmpptr = reinterpret_cast<T*>(&tmpdata);
#pragma unroll
        for (int s = 0; s < kVSize; ++s) {
          if (mode == SoftmaxMode::kLogSoftmax) {  // log softmax
            tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
          } else if (mode == SoftmaxMode::kCrossEntropy) {
            // softmax with cross entropy hard label
            AccT logsoftmax = srcdata[i][it][s] - max_value[i] - sum[i];
            // softmax
            tmpptr[s] = std::exp(logsoftmax);
            // loss
            int loss_idx = idx * kVSize + s;
            auto lbl = static_cast<int64_t>(label[first_batch + i]);
            if (lbl >= 0 && lbl < element_count) {
              if (lbl == loss_idx) {
                if (IgnoreIndex && lbl == ignore_index) {
                  loss[first_batch + i] = static_cast<T>(0.0);
                } else {
                  loss[first_batch + i] = -logsoftmax;
                }
              }
            } else {
              loss[first_batch + i] = static_cast<T>(0.0);
            }
          } else {  // softmax
            tmpptr[s] = srcdata[i][it][s] / sum[i];
          }
        }
        softmax_v[idx] = tmpdata;
      }
    }
  }
}

391
// Expands to one `case` of the switch in SwitchWarpSoftmaxForward, launching
// WarpSoftmaxForward specialized for 2^LOG2_ELEMS columns. The expansion uses
// T, mode, IgnoreIndex, blocks, threads, stream, loss, softmax, src, label,
// batch_size, stride, element_count and ignore_index from the enclosing
// scope, so all of them must be visible where the macro is used.
#define SOFTMAX_WARP_FORWARD_CASE(LOG2_ELEMS, LABEL_T, VEC_T, ACC_T)          \
  case LOG2_ELEMS: {                                                          \
    WarpSoftmaxForward<T, LABEL_T, VEC_T, ACC_T, LOG2_ELEMS, mode,            \
                       IgnoreIndex><<<blocks, threads, 0, stream>>>(          \
        loss, softmax, src, label, batch_size, stride, element_count,         \
        ignore_index);                                                        \
  } break;

/*
  Wrapper of softmax with cross entropy forward hard label.
  Launches WarpSoftmaxForward on `stream` for batch_size rows of element_count
  values each (row r starts at src + r * stride). Only element_count <= 512
  (Log2Ceil in 0..9) is dispatched; for larger sizes nothing is launched, so
  callers must route those elsewhere.
*/
template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex>
void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
                              const LabelT* label, const int batch_size,
                              const int stride, const int element_count,
                              const int ignore_index, gpuStream_t stream) {
  using AccT = typename details::MPTypeTrait<T>::Type;

  const int log2_elements = Log2Ceil(element_count);
  // No rows: nothing to do, and a zero-sized grid is an invalid launch
  // configuration. Unsupported sizes launch nothing (exactly what the
  // `default` branch below does); returning first also keeps the shift
  // below well-defined.
  if (batch_size <= 0 || log2_elements > 9) return;

  // These mirror the constexpr kWarpSize / kBatchSize that WarpSoftmaxForward
  // derives from Log2Elements and must stay in sync with them.
  const int kDimCeil = 1 << log2_elements;
  const int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  const int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
  // use 128 threads per block to maximize gpu utilization
  constexpr int threads_per_block = 128;
  const int warps_per_block = threads_per_block / kWarpSize;
  const int batches_per_block = warps_per_block * batches_per_warp;
  const int blocks = (batch_size + batches_per_block - 1) / batches_per_block;
  dim3 threads(kWarpSize, warps_per_block, 1);

  switch (log2_elements) {
    SOFTMAX_WARP_FORWARD_CASE(0, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(1, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(2, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(3, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(4, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(5, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(6, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(7, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(8, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(9, LabelT, T, AccT);
    default:
      break;
  }
}

436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696
/*
  Writes loss[label_id] when the element handled by this thread is the
  labelled one. The element's column is vec_size * tid + offset; for any other
  column nothing is written. With IgnoreIndex, a label equal to ignore_index
  writes 0 instead of loss_value.
*/
template <typename T, bool IgnoreIndex>
__device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value,
                                            const int label_id,
                                            const int64_t label_value,
                                            const int tid, const int vec_size,
                                            const int offset,
                                            const int ignore_index) {
  const int column = vec_size * tid + offset;
  if (label_value != column) {
    return;
  }
  const bool ignored = IgnoreIndex && (label_value == ignore_index);
  loss[label_id] = ignored ? static_cast<T>(0.0f) : loss_value;
}

/*
  Per-thread partial reduction of input[0, size) across threadIdx.x.
  `offset` is the number of elements by which `input` lies past the previous
  ALIGN_BYTES boundary (as computed by VectorizedSoftmaxForward). When it is
  non-zero, the first blockDim.x elements starting at that boundary are
  reduced one per thread (skipping the `offset` leading ones and anything past
  the end), so that the remainder starts aligned and can be read as VecT.
  The remainder is read VecSize elements per thread, then a scalar tail.
  Returns this thread's partial value; the caller combines across the block.
*/
template <typename T, typename AccT, int VecSize, class ReduceFunctor>
__device__ __forceinline__ AccT ThreadReduce(const T* input, int size,
                                             const int offset, AccT init,
                                             ReduceFunctor reducer) {
  using VecT = kps::details::VectorType<T, VecSize>;
  int tid = threadIdx.x;
  AccT val = init;

  if (offset > 0) {
    input -= offset;
    size += offset;
    // `tid < size` guards rows shorter than one block.
    if (tid >= offset && tid < size) {
      val = reducer(val, input[tid]);
    }
    size -= static_cast<int>(blockDim.x);
    input += blockDim.x;
    // The whole row fit in the head segment; without this, a negative size
    // would send the tail loop below to negative indices.
    if (size <= 0) return val;
  }
  const int remain = size % (VecSize * static_cast<int>(blockDim.x));

  // Stage vector loads in a VecT object so the buffer carries VecT's
  // alignment, and view its storage as T elements.
  VecT ins_vec;
  const T* ins = reinterpret_cast<const T*>(&ins_vec);

  // vector part
  for (; VecSize * tid < (size - remain); tid += blockDim.x) {
    ins_vec = reinterpret_cast<const VecT*>(input)[tid];

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      val = reducer(val, ins[i]);
    }
  }

  // scalar part
  tid = size - remain + threadIdx.x;
  for (; tid < size; tid += blockDim.x) {
    val = reducer(val, input[tid]);
  }
  return val;
}

/*
  Writes softmax (and the hard-label loss) for the row of block blockIdx.x
  when logits and softmax share the same misalignment `offset` (in elements)
  relative to ALIGN_BYTES. The first blockDim.x elements from the aligned base
  are processed one per thread (skipping the `offset` leading ones and any
  past the end), the rest as VecT vectors, then a scalar tail.
  label[blockIdx.x] selects the target column; a label outside [0, size)
  writes a loss of 0 (once, from thread 0).
*/
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
    T* loss, T* softmax, const T* logits, const LabelT* label, int size,
    const int offset, const LogSoftmaxForwardFunctor<AccT>& func,
    const int ignore_index) {
  using VecT = kps::details::VectorType<T, VecSize>;
  int tid = threadIdx.x;
  int label_id = blockIdx.x;
  auto label_value = static_cast<int64_t>(label[label_id]);
  const bool label_valid = label_value >= 0 && label_value < size;
  // Column of element 0 of the current base pointer, relative to the row.
  int loss_id_offset = 0;

  if (offset > 0) {
    logits -= offset;
    softmax -= offset;
    size += offset;
    loss_id_offset -= offset;
    // `tid < size` guards rows shorter than one block.
    if (tid >= offset && tid < size) {
      AccT log_softmax = func(static_cast<AccT>(logits[tid]));
      softmax[tid] = static_cast<T>(std::exp(log_softmax));
      // loss
      if (label_valid) {
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                    label_id, label_value, tid, 1,
                                    loss_id_offset, ignore_index);
      }
    }
    size -= static_cast<int>(blockDim.x);
    logits += blockDim.x;
    softmax += blockDim.x;
    loss_id_offset += blockDim.x;
  }

  // Skipped when the whole row fit in the head segment above.
  if (size > 0) {
    const int remain = size % (VecSize * static_cast<int>(blockDim.x));

    // Stage through VecT objects so the buffers carry VecT's alignment.
    VecT ins_vec;
    VecT outs_vec;
    const T* ins = reinterpret_cast<const T*>(&ins_vec);
    T* outs = reinterpret_cast<T*>(&outs_vec);

    // vector part
    for (; VecSize * tid < (size - remain); tid += blockDim.x) {
      // read
      ins_vec = reinterpret_cast<const VecT*>(logits)[tid];

      // compute
#pragma unroll
      for (int i = 0; i < VecSize; ++i) {
        AccT log_softmax = func(static_cast<AccT>(ins[i]));
        outs[i] = static_cast<T>(std::exp(log_softmax));

        // loss
        if (label_valid) {
          ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                      label_id, label_value, tid, VecSize,
                                      loss_id_offset + i, ignore_index);
        }
      }

      // write
      reinterpret_cast<VecT*>(softmax)[tid] = outs_vec;
    }

    // scalar part
    tid = size - remain + threadIdx.x;
    for (; tid < size; tid += blockDim.x) {
      AccT log_softmax = func(static_cast<AccT>(logits[tid]));
      softmax[tid] = static_cast<T>(std::exp(log_softmax));

      // loss
      if (label_valid) {
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                    label_id, label_value, tid, 1,
                                    loss_id_offset, ignore_index);
      }
    }
  }

  // invalid label, write once
  if (!label_valid && threadIdx.x == 0) {
    loss[label_id] = static_cast<T>(0.0f);
  }
}

/*
  Writes softmax (and the hard-label loss) for the row of block blockIdx.x
  using scalar accesses only. This path is used when logits and softmax are
  not equally aligned. In the main loop each thread handles columns tid,
  tid + blockDim.x, ..., tid + (VecSize - 1) * blockDim.x per iteration; a
  scalar tail covers the remainder. A label outside [0, size) writes a loss
  of 0 (once, from thread 0).
*/
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
    T* loss, T* softmax, const T* logits, const LabelT* label, const int size,
    const LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
  int tid = threadIdx.x;
  const int block_stride = static_cast<int>(blockDim.x);
  const int remain = size % (VecSize * block_stride);
  const int label_id = blockIdx.x;
  const auto label_value = static_cast<int64_t>(label[label_id]);
  const bool label_valid = label_value >= 0 && label_value < size;

  // main part
  for (; tid < (size - remain); tid += VecSize * block_stride) {
    T ins[VecSize];

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      ins[i] = logits[tid + i * block_stride];
    }
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      AccT log_softmax = func(static_cast<AccT>(ins[i]));
      softmax[tid + i * block_stride] = static_cast<T>(std::exp(log_softmax));
      // loss
      if (label_valid) {
        // Element i of this iteration is at column tid + i * blockDim.x, so
        // the column is passed as (tid, vec_size = 1, offset = i * stride)
        // rather than VecSize * tid + i.
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                    label_id, label_value, tid, 1,
                                    i * block_stride, ignore_index);
      }
    }
  }

  // tail part
  for (; tid < size; tid += block_stride) {
    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
    softmax[tid] = static_cast<T>(std::exp(log_softmax));
    // loss
    if (label_valid) {
      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
                                  label_value, tid, 1, 0, ignore_index);
    }
  }

  // invalid label, write once
  if (!label_valid && threadIdx.x == 0) {
    loss[label_id] = static_cast<T>(0.0f);
  }
}

/*
  Softmax with hard-label cross entropy, one block per row of mid_dim
  elements (the grid is sized by LaunchVectorizedSoftmaxForward). Steps:
    1. block-wide max of the row,
    2. block-wide sum of exp(x - max),
    3. softmax and loss via the LogSoftmaxForwardFunctor built from (max, sum).
  If logits and softmax lie at the same offset from an ALIGN_BYTES boundary,
  the vectorized writer is used; otherwise the scalar one.
  high_dim is part of the signature but is not read inside the kernel.
*/
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
                                         const LabelT* label,
                                         const int high_dim, const int mid_dim,
                                         const int ignore_index) {
  // each block deals with one row; widen before multiplying so row offsets
  // beyond 2^32 elements do not wrap in 32-bit unsigned arithmetic.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * mid_dim;
  logits += row_offset;
  softmax += row_offset;

  // Misalignment of each row pointer, in elements of T.
  const int input_offset =
      reinterpret_cast<uint64_t>(logits) % ALIGN_BYTES / sizeof(T);
  const int output_offset =
      reinterpret_cast<uint64_t>(softmax) % ALIGN_BYTES / sizeof(T);

  // 1. reduce max
  AccT max = ThreadReduce<T, AccT, VecSize, kps::MaxFunctor<AccT>>(
      logits, mid_dim, input_offset, -std::numeric_limits<AccT>::infinity(),
      kps::MaxFunctor<AccT>());
  max = kps::details::BlockXReduce<AccT, kps::MaxFunctor<AccT>>(
      max, kps::MaxFunctor<AccT>());

  // 2. reduce sum
  AccT sum = ThreadReduce<T, AccT, VecSize, ExpAddFunctor<AccT>>(
      logits, mid_dim, input_offset, static_cast<AccT>(0),
      ExpAddFunctor<AccT>(max));
  sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
      sum, kps::AddFunctor<AccT>());

  // 3. softmax
  LogSoftmaxForwardFunctor<AccT> func(max, sum);
  if (input_offset == output_offset) {
    VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
        loss, softmax, logits, label, mid_dim, input_offset, func,
        ignore_index);
  } else {
    ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
        loss, softmax, logits, label, mid_dim, func, ignore_index);
  }
}

/*
  Launches VectorizedSoftmaxForward on `stream` with one block per row
  (high_dim rows of mid_dim elements). The block size is a power of two, at
  least one warp, derived from mid_dim / vec_size (halved when vec_size > 1)
  and capped at 1024.
*/
template <typename T, typename LabelT, bool IgnoreIndex>
void LaunchVectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
                                    const LabelT* label, const int high_dim,
                                    const int mid_dim, const int ignore_index,
                                    gpuStream_t stream) {
  // No rows: nothing to do, and an empty grid is an invalid launch
  // configuration.
  if (high_dim <= 0) return;

  using AccT = typename details::MPTypeTrait<T>::Type;
  // Number of T elements per 16-byte (float4) vector access.
  constexpr int vec_size = sizeof(float4) / sizeof(T);
  const int max_num_threads = 1024;
  int max_block_size = std::min(mid_dim / vec_size, max_num_threads);
  if (vec_size > 1) {
    max_block_size /= 2;
  }

  // Round up to a power of two, but never below one warp.
  int block_size = 1;
  while (block_size < max_block_size) {
    block_size *= 2;
  }
  block_size = std::max(block_size, kps::details::kWarpSize);
  dim3 grids(high_dim);
  dim3 blocks(block_size);
  VectorizedSoftmaxForward<T, AccT, LabelT, vec_size,
                           IgnoreIndex><<<grids, blocks, 0, stream>>>(
      loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
}

697 698
/*
  Wrapper of softmax with cross entropy hard label.
  logits are viewed as [N, dim, D]; loss_data receives N * D values and
  softmax_data the softmax probabilities.
  - D == 1, dim <= 320: SwitchWarpSoftmaxForward (warp per row)
  - D == 1, dim > 320:  LaunchVectorizedSoftmaxForward (block per row)
  - D != 1: cuDNN/MIOpen log-softmax into softmax_data, then
    CrossEntropyExpHardLabel computes the loss and converts softmax_data from
    log-softmax to softmax in place.
*/
template <typename T, typename LabelT, bool IgnoreIndex>
static void SoftmaxWithCrossEntropyHardLabel(
    const platform::CUDADeviceContext& ctx, int rank, int axis,
    const T* logits_data, const LabelT* labels_data, T* loss_data,
    T* softmax_data, int N, int dim, int D, const int ignore_index) {
  // No loss entries to produce; also avoids launching empty grids, which the
  // runtime rejects as an invalid configuration.
  if (N <= 0 || D <= 0) return;

  auto stream = ctx.stream();
  constexpr int max_dim = 320;
  if (D == 1) {
    if (dim <= max_dim) {  // small size
      const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
      SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
          loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
          ignore_index, stream);
    } else {  // large size
      LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(
          loss_data, softmax_data, logits_data, labels_data, N, dim,
          ignore_index, stream);
    }
    return;
  }

  ScopedTensorDescriptor desc;
  std::vector<int> tensor_dims = {N, dim, D, 1};
  DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#else
  cudnnTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#endif

  auto handle = ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
  auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                               : MIOPEN_SOFTMAX_MODE_CHANNEL;
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
      handle, platform::CudnnDataType<T>::kOne(), descp, logits_data,
      platform::CudnnDataType<T>::kZero(), descp, softmax_data,
      MIOPEN_SOFTMAX_LOG, mode));
#else
  auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                               : CUDNN_SOFTMAX_MODE_CHANNEL;
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward(
      handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
      descp, logits_data, platform::CudnnDataType<T>::kZero(), descp,
      softmax_data));
#endif

  // compute cross entropy, input is log softmax
  const int threads = 128;
  // 64-bit element count: N * dim * D can exceed INT_MAX for large inputs.
  const int64_t numel = static_cast<int64_t>(N) * dim * D;
  const int blocks = static_cast<int>((numel + threads - 1) / threads);
  CrossEntropyExpHardLabel<T, LabelT,
                           IgnoreIndex><<<blocks, threads, 0, stream>>>(
      loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
}

/*
  Wrapper of softmax with cross entropy grad hard label.
  Updates logits_grad in place, laid out as [n, dim, d], with
  loss_grad / labels indexed by [n, d]:
    label == ignore_index : logits_grad = 0
    column == label       : logits_grad = (logits_grad - 1) * loss_grad
    otherwise             : logits_grad = logits_grad * loss_grad
  (presumably logits_grad holds the forward softmax on entry — confirm at the
  call site). Launch: 1-D grid, one thread per element; extra threads exit.
*/
template <typename T, typename LabelT>
__global__ void SoftmaxWithCrossEntropyGradHardLabel(
    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
    const int64_t dim, const int64_t d, const int ignore_index) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise computed
  // in 32-bit unsigned arithmetic and wraps for very large tensors.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Bounds check first, so d == 0 or dim == 0 never reaches the divisions.
  if (idx >= n * dim * d) return;

  const int64_t idx_n = idx / (d * dim);
  const int64_t idx_dim = (idx / d) % dim;
  const int64_t idx_d = idx % d;
  const int64_t ids = idx_n * d + idx_d;

  const auto lbl = static_cast<int64_t>(labels[ids]);
  if (lbl == ignore_index) {
    logits_grad[idx] = static_cast<T>(0.0);
  } else if (lbl == idx_dim) {
    logits_grad[idx] =
        (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
  } else {
    logits_grad[idx] *= loss_grad[ids];
  }
}

783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850
/*
  Cross entropy soft label with dynamic size on axis (log2_elements is
  variable).
  - if the input is softmax, compute the loss from softmax
  - if the input is log_softmax, compute the loss from log_softmax and write
    exp(log_softmax) back as the softmax
*/
template <typename T, typename VecT, bool InLogMode = false>
__global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax,
                                      const T* labels, const int n,
                                      const int dim, const int d,
                                      int log2_elements) {
  const int kDimCeil = 1 << log2_elements;
  const int kVSize = sizeof(VecT) / sizeof(T);

#ifdef __HIPCC__
  const int kThreadPerBlock = 256;
#else
  const int kThreadPerBlock = 512;
#endif
  const int kBatchPerBlock = 1;
  const int kWarpSize = 32;  // (dim < 32) ? dim : 32;
  const int kBatchSize = 1;
  const int kThreadPerBatch = kThreadPerBlock / kBatchPerBlock;
  const int kWarpPerBatch = kThreadPerBatch / kWarpSize;

  const int kIterations = (dim + kThreadPerBatch - 1) / kThreadPerBatch;
  const int kIterationsV = (kIterations >= kVSize) ? (kIterations / kVSize) : 1;

  const int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;

  T sum[kBatchSize]{static_cast<T>(0.0)};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    int ids = first_batch + i;
    if (ids >= n * d) break;
    int idx_n = ids / d;
    int idx_d = ids % d;
#pragma unroll
    for (int it = 0; it < kIterations; ++it) {
      int idx_dim = it * kThreadPerBatch + threadIdx.x;
      int idx = idx_n * dim * d + idx_dim * d + idx_d;

      if (idx_n < n && idx_dim < dim) {
        VecT softmaxdata;
        if (InLogMode) {
          softmaxdata = reinterpret_cast<VecT*>(&softmaxwrt[idx])[0];
        } else {
          softmaxdata = reinterpret_cast<const VecT*>(&softmax[idx])[0];
        }
        VecT labelsdata = reinterpret_cast<const VecT*>(&labels[idx])[0];
        T* softmaxptr = reinterpret_cast<T*>(&softmaxdata);
        T* labelsptr = reinterpret_cast<T*>(&labelsdata);
#pragma unroll
        for (int s = 0; s < kVSize; s++) {
          if (InLogMode) {
            sum[i] -= softmaxptr[s] * labelsptr[s];
            softmaxptr[s] = Exp(softmaxptr[s]);
          } else {
            sum[i] -= Log(softmaxptr[s]) * labelsptr[s];
          }
        }
        if (InLogMode) {
          reinterpret_cast<VecT*>(&softmaxwrt[idx])[0] = softmaxdata;
        }
      }
    }
  }
851
  phi::WarpReduceSum<T, kBatchSize, kWarpSize>(sum);
852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958
  __syncthreads();

  __shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize];
  if (threadIdx.x % kWarpSize == 0) {
#pragma unroll
    for (int i = 0; i < kBatchSize; i++) {
      sumshare[threadIdx.x / kWarpSize][threadIdx.y][i] = sum[i];
    }
  }
  __syncthreads();

  // write
  if (threadIdx.x == 0) {
    for (int i = 0; i < kBatchSize; i++) {
      int ids = first_batch + i;
      if (ids < n * d) {
        loss[ids] = sumshare[0][threadIdx.y][i];
        for (int s = 1; s < kWarpPerBatch; s++) {
          loss[ids] += sumshare[s][threadIdx.y][i];
        }
      }
    }
  }
}

/*
Core function of softmax with cross entropy forward soft label.
For each row i of `src` (rows are `stride` elements apart; only the first
`element_count` elements of a row are valid) it computes
  - maxvalue_{i}      = max_j src_{i,j}
  - logsum_{i}        = log(sum_j exp(src_{i,j} - maxvalue_{i}))
  - logsoftmax_{i,j}  = src_{i,j} - maxvalue_{i} - logsum_{i}
  - softmax_{i,j}     = exp(logsoftmax_{i,j})          (written to `softmax`)
  - loss_{i}          = -sum_j label_{i,j} * logsoftmax_{i,j}  (to `loss`)
One warp of min(2^Log2Elements, 32) lanes handles kBatchSize rows (2 when
2^Log2Elements <= 128, else 1). Each lane reduces its own elements first, then
phi::WarpReduceMax / phi::WarpReduceSum combine the lanes of the warp.
Launch with blockDim.x == min(2^Log2Elements, 32); blockDim.y warps per block.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src,
                                            const T* label,
                                            const int batch_size,
                                            const int stride,
                                            const int element_count) {
  const bool LogMode = true;

  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kIterations = kDimCeil / kWarpSize;
  constexpr int kIterationsV =
      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  int local_batches = batch_size - first_batch;
  if (local_batches > kBatchSize) {
    local_batches = kBatchSize;
  }

  // Read data from global memory. Lanes past element_count (or past the last
  // row) are padded with src = -max so they never win the max, and their
  // exp() is 0. Their label is set to 0.
  VecT srcdata[kBatchSize][kIterationsV];
  VecT labeldata[kBatchSize][kIterationsV];

  for (int i = 0; i < kBatchSize; ++i) {
    const VecT* src_v =
        reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
    const VecT* label_v =
        reinterpret_cast<const VecT*>(&label[(first_batch + i) * stride]);

    // max index to read
    int idx_max = (i < local_batches) ? element_count : 0;
    int idx_max_v = idx_max / kVSize;

    // read data
    for (int it = 0; it < kIterationsV; ++it) {
      int src_idx = threadIdx.x + it * kWarpSize;
      if (src_idx < idx_max_v) {
        srcdata[i][it] = src_v[src_idx];
        labeldata[i][it] = label_v[src_idx];
      } else {
#pragma unroll
        for (int s = 0; s < kVSize; s++) {
          reinterpret_cast<T*>(&srcdata[i][it])[s] =
              -std::numeric_limits<AccT>::max();
          reinterpret_cast<T*>(&labeldata[i][it])[s] = 0.0;
        }
      }
    }
  }

  // compute max value
  AccT max_value[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    max_value[i] = -std::numeric_limits<AccT>::infinity();
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      T* srcptr_v = reinterpret_cast<T*>(&srcdata[i][it]);
      T valmax = srcptr_v[0];
#pragma unroll
      for (int s = 1; s < kVSize; ++s) {
        valmax = (valmax > srcptr_v[s]) ? valmax : srcptr_v[s];
      }
      max_value[i] = (max_value[i] > static_cast<AccT>(valmax))
                         ? max_value[i]
                         : static_cast<AccT>(valmax);
    }
  }
  phi::WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);

  // compute sum
  AccT sum[kBatchSize]{0.0};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      T* srcptr_v = reinterpret_cast<T*>(&srcdata[i][it]);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          sum[i] += std::exp(static_cast<AccT>(srcptr_v[s]) - max_value[i]);
        } else {
          srcptr_v[s] = std::exp(static_cast<AccT>(srcptr_v[s]) - max_value[i]);
          sum[i] += static_cast<AccT>(srcptr_v[s]);
        }
      }
    }
  }
  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // log_softmax and loss
  AccT sumloss[kBatchSize]{0.0};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;

    VecT* softmax_v =
        reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);

    // Max index to write. Rows past local_batches were skipped by the break
    // above, so every remaining row has element_count valid elements.
    const int idx_max_v = element_count / kVSize;

    if (LogMode) {
      sum[i] = std::log(sum[i]);
    }
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int idx = threadIdx.x + it * kWarpSize;
      // Padded lanes must not contribute to the loss. For float16 the -max
      // padding is out of half range and converts to -inf, and -inf times its
      // zero label is NaN, which the warp reduction would spread over the
      // whole row.
      if (idx >= idx_max_v) {
        continue;
      }

      T* srcvp = reinterpret_cast<T*>(&srcdata[i][it]);
      T* labelvp = reinterpret_cast<T*>(&labeldata[i][it]);
      VecT tmpv;
      T* tmpvp = reinterpret_cast<T*>(&tmpv);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          AccT logsoftmax = static_cast<AccT>(srcvp[s]) - max_value[i] - sum[i];
          sumloss[i] -= logsoftmax * static_cast<AccT>(labelvp[s]);
          tmpvp[s] = std::exp(logsoftmax);
        } else {
          tmpvp[s] = static_cast<AccT>(srcvp[s]) / sum[i];
        }
      }
      softmax_v[idx] = tmpv;
    }
  }

  // loss
  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);

  for (int i = 0; i < kBatchSize; i++) {
    if (i >= local_batches) break;
    loss[first_batch + i] = sumloss[i];
  }
}

// Expands to a `case Log2Elements:` label that launches
// WarpSoftmaxForwardSoftLabel<T, VecT, AccT, Log2Elements> and then breaks.
// Meant for use inside a switch on log2(dim). It captures `T`, `blocks`,
// `threads`, `stream`, `loss`, `softmax`, `src`, `label`, `batch_size`,
// `stride` and `element_count` from the enclosing scope, so those names must
// be in scope where it is used.
#define SOFTMAX_WARP_FORWARD_SOFT_CASE(Log2Elements, VecT, AccT)               \
  case Log2Elements:                                                           \
    WarpSoftmaxForwardSoftLabel<T, VecT, AccT,                                 \
                                Log2Elements><<<blocks, threads, 0, stream>>>( \
        loss, softmax, src, label, batch_size, stride, element_count);         \
    break;

/*
  Host-side dispatcher for the fused soft-label forward kernel.
  Turns the runtime `log2_elements` (ceil(log2(dim))) into the compile-time
  Log2Elements template argument of WarpSoftmaxForwardSoftLabel. Only values
  0..9 (rows of up to 512 elements) have an instantiation; for any other value
  no kernel is launched.
*/
template <typename T>
void SwitchWarpSoftmaxForwardSoftLabel(const int blocks, const dim3 threads,
                                       gpuStream_t stream, T* loss, T* softmax,
                                       const T* src, const T* label,
                                       const int batch_size, const int stride,
                                       const int element_count,
                                       const int log2_elements) {
  // Intermediate math runs in the MPTypeTrait type (wider than T for float16).
  using MPType = typename details::MPTypeTrait<T>::Type;
  switch (log2_elements) {
    SOFTMAX_WARP_FORWARD_SOFT_CASE(0, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(1, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(2, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(3, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(4, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(5, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(6, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(7, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(8, T, MPType);
    SOFTMAX_WARP_FORWARD_SOFT_CASE(9, T, MPType);
    default:
      // No specialization for this size: nothing is launched.
      return;
  }
}

/*
  Forward softmax + soft-label cross entropy on raw logits.
  Tensors are viewed as [N, dim, D] (loss as [N, D]).
  - D == 1 and dim <= 320: one fused warp kernel computes the softmax and the
    loss.
  - Otherwise: cuDNN/MIOpen writes log-softmax into `softmax_data`, then
    CrossEntropySoftLabel<T, T, true> computes the loss and turns
    `softmax_data` into the softmax in place.
*/
template <typename T>
static void SoftmaxWithCrossEntropySoftLabel(
    const platform::CUDADeviceContext& ctx, const int rank, const int axis,
    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
    int N, int dim, int D) {
  // Largest class count handled by the fused warp kernel.
  constexpr int max_dim = 320;

  const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
  const int kDimCeil = 1 << kDimLog2;
  auto stream = ctx.stream();

  if (D == 1 && dim <= max_dim) {
    // These must mirror kWarpSize / kBatchSize inside
    // WarpSoftmaxForwardSoftLabel.
    int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
    int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / kWarpSize);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(kWarpSize, warps_per_block, 1);

    SwitchWarpSoftmaxForwardSoftLabel<T>(blocks, threads, stream, loss_data,
                                         softmax_data, logits_data, labels_data,
                                         N, dim, dim, kDimLog2);

  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};
    DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
    miopenTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#else
    cudnnTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#endif

    auto handle = ctx.cudnn_handle();

    // Log-softmax over the class axis, written into softmax_data.
#ifdef PADDLE_WITH_HIP
    auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                 : MIOPEN_SOFTMAX_MODE_CHANNEL;
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
        handle, platform::CudnnDataType<T>::kOne(), descp, logits_data,
        platform::CudnnDataType<T>::kZero(), descp, softmax_data,
        MIOPEN_SOFTMAX_LOG, mode));
#else
    auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                 : CUDNN_SOFTMAX_MODE_CHANNEL;
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward(
        handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
        descp, logits_data, platform::CudnnDataType<T>::kZero(), descp,
        softmax_data));
#endif

    // Block shape must match the constants inside CrossEntropySoftLabel.
#ifdef __HIPCC__
    int kThreadPerBlock = 256;
#else
    int kThreadPerBlock = 512;
#endif

    int kBatchPerBlock = 1;
    int blocks = (N * D + kBatchPerBlock - 1) / kBatchPerBlock;
    dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1);

    CrossEntropySoftLabel<T, T, true><<<blocks, threads, 0, stream>>>(
        loss_data, softmax_data, NULL, labels_data, N, dim, D, kDimLog2);
  }
}

/*
  Gradient of softmax + soft-label cross entropy w.r.t. the logits, in place:
    logit_grad = loss_grad * (softmax - label)
  `logit_grad` must hold the forward softmax on entry. `logit_grad` and
  `labels` have n * d elements, and `loss_grad` is indexed as [n, remain],
  where `remain` is the size of the axes after the class axis (d / dim at the
  call site in this file). One thread per element; launch at least n * d
  threads.
*/
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int64_t n,
                                               const int64_t d,
                                               const int64_t remain) {
  // Widen before multiplying so the index cannot wrap in 32-bit arithmetic.
  int64_t ids = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int64_t idx_n = ids / d;
    int64_t idx_remain = ids % remain;
    int64_t idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}

/*
  Gradient of soft-label cross entropy w.r.t. its probability input (the
  use_softmax == false path), in place:
    logit_grad = loss_grad * (-label / p)
  where `logit_grad` holds the probabilities p on entry. `logit_grad` and
  `labels` have n * d elements, and `loss_grad` is indexed as [n, remain].
  Indices are 64-bit so tensors with more than 2^31 elements do not overflow;
  the caller passes int64_t sizes. One thread per element; launch at least
  n * d threads.
*/
template <typename T>
__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
                                                    const T* loss_grad,
                                                    const T* labels,
                                                    const int64_t n,
                                                    const int64_t d,
                                                    const int64_t remain) {
  int64_t ids = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int64_t idx_n = ids / d;
    int64_t idx_remain = ids % remain;
    int64_t idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]);
  }
}

/*
  First step of the hard-label gradient when the input is already a softmax.
  For every [n, remain] position whose label differs from ignore_index, the
  probability p stored at the labelled class is replaced in place by -1 / p.
  All other entries of `logit_grad` are left untouched; the grad kernel in
  this file then runs ScaleCrossEntropyGradient over the full tensor.
*/
template <typename T, typename LabelT>
__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
                                                    const LabelT* labels,
                                                    const int n, const int d,
                                                    const int remain,
                                                    const int ignore_index) {
  CUDA_KERNEL_LOOP(index, n * remain) {
    const int label = static_cast<int>(labels[index]);
    if (label != ignore_index) {
      const int row = index / remain;
      const int col = index % remain;
      T* target = logit_grad + (row * d + label * remain + col);
      *target = -static_cast<T>(1.) / *target;
    }
  }
}

/*
  Second step of the hard-label gradient when the input is already a softmax.
  `logit_grad` has `num` elements indexed as [n, dim, remain] with
  d = dim * remain. An entry at the labelled class is multiplied by
  loss_grad[n, remain]; every other entry, and every entry of a row whose
  label equals ignore_index, is set to 0.
*/
template <typename T, typename LabelT>
__global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
                                          const int num, const int d,
                                          const int remain,
                                          const LabelT* labels,
                                          const int ignore_index) {
  CUDA_KERNEL_LOOP(index, num) {
    const int outer = index / d;
    const int inner = index % remain;
    const int label_pos = outer * remain + inner;
    const int cls = (index % d) / remain;

    const auto label = static_cast<int64_t>(labels[label_pos]);
    const bool keep = (label != ignore_index) && (label == cls);
    if (keep) {
      logit_grad[index] *= loss_grad[label_pos];
    } else {
      logit_grad[index] = static_cast<T>(0.);
    }
  }
}

/*
  CUDA forward kernel of softmax_with_cross_entropy.
  Apply<LabelT> is invoked through RunSoftmaxWithCrossEntropyFunctor with the
  label element type:
  - use_softmax == false: "Logits" already holds probabilities. Only the
    cross entropy is computed, and the input is copied to "Softmax".
  - use_softmax == true: softmax and loss are computed from the raw logits.
*/
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
  }

  template <typename LabelT>
  static void Apply(const framework::ExecutionContext& context,
                    const framework::Tensor& labels, const bool soft_label) {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const bool use_softmax = context.Attr<bool>("use_softmax");

    // No softmax op: the input is already a softmax.
    if (!use_softmax) {
      const Tensor* softmax = context.Input<Tensor>("Logits");
      Tensor* softmax_out = context.Output<Tensor>("Softmax");
      Tensor* loss = context.Output<Tensor>("Loss");

      const int rank = softmax->dims().size();
      const int axis =
          phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
      const int axis_dim = softmax->dims()[axis];

      const int n = phi::funcs::SizeToAxis(axis, softmax->dims());
      const int d = phi::funcs::SizeFromAxis(axis, softmax->dims());

      auto* softmax_out_data =
          softmax_out->template mutable_data<T>(context.GetPlace());
      auto* loss_data = loss->template mutable_data<T>(context.GetPlace());

      phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      if (axis_dim == 1) {
        set_constant(context.cuda_device_context(), softmax_out,
                     static_cast<T>(1));
        return;
      }

      auto ignore_index = context.Attr<int>("ignore_index");

      Tensor softmax_2d, labels_2d, loss_2d;
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, 1});

      // math::CrossEntropyFunctor only supports reducing over the last axis.
      // NOTE(review): `axis` was canonicalized by CanonicalAxis above, so it
      // appears it can never equal -1 here and this branch looks unreachable.
      // The kernels below also handle the last axis. Confirm whether
      // `axis == rank - 1` was intended before changing the condition.
      if (axis == -1) {
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            soft_label, ignore_index, axis_dim);
        return;
      }

      // When the axis is not the last one, use the dedicated kernels.
      if (soft_label) {
        auto* logits_data = softmax->template data<T>();
        auto* labels_data = labels.template data<T>();

        const int kDimLog2 = static_cast<int>(Log2Ceil(axis_dim));
#ifdef __HIPCC__
        int kThreadPerBlock = 256;
#else
        int kThreadPerBlock = 512;
#endif
        int kBatchPerBlock = 1;
        // One block row per (n, remain) pair. The kernel idles any row beyond
        // n * (d / axis_dim), so launching n * d blocks only added axis_dim
        // times as many idle blocks.
        const int remain = d / axis_dim;
        int blocks = (n * remain + kBatchPerBlock - 1) / kBatchPerBlock;
        dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1);

        CrossEntropySoftLabel<T, T, false><<<
            blocks, threads, 0, context.cuda_device_context().stream()>>>(
            loss_data, NULL, logits_data, labels_data, n, axis_dim, remain,
            kDimLog2);
      } else {  // HardLabel
        auto* logits_data = softmax->template data<T>();
        auto* labels_data = labels.template data<LabelT>();
        int threads = 128;
        int blocks = (n * d / axis_dim + threads - 1) / threads;
        if (ignore_index >= 0 && ignore_index < axis_dim) {
          CrossEntropyHardLabel<T, LabelT, true><<<
              blocks, threads, 0, context.cuda_device_context().stream()>>>(
              loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
              ignore_index);
        } else {
          CrossEntropyHardLabel<T, LabelT, false><<<
              blocks, threads, 0, context.cuda_device_context().stream()>>>(
              loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
              ignore_index);
        }
      }

      // The input already is the softmax, so copy it straight to the output.
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), softmax_out);

      return;
    }

    const Tensor* logits = context.Input<Tensor>("Logits");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    const int64_t n = phi::funcs::SizeToAxis(axis, logits->dims());
    const int64_t d = phi::funcs::SizeFromAxis(axis, logits->dims());

    auto* softmax_data = softmax->template mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->template mutable_data<T>(context.GetPlace());

    // A single class: softmax is 1 everywhere and the loss is 0.
    if (axis_dim == 1) {
      phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      return;
    }

    auto ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      auto* logits_data = logits->template data<T>();
      auto* labels_data = labels.template data<T>();
      SoftmaxWithCrossEntropySoftLabel<T>(
          context.cuda_device_context(), rank, axis, logits_data, labels_data,
          softmax_data, loss_data, n, axis_dim, d / axis_dim);
    } else {
      if (!context.Attr<bool>("numeric_stable_mode")) {
        // The cuDNN kernel only supports 2-D tensors and performs softmax on
        // the last dim.
        Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
        logits_2d.ShareDataWith(*logits).Resize({n, d});
        softmax_2d.ShareDataWith(*softmax).Resize({n, d});
        labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
        loss_2d.ShareDataWith(*loss).Resize({n, 1});
        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                       &logits_2d, &softmax_2d);
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            false, ignore_index, axis_dim);
      } else {
        auto* logits_data = logits->template data<T>();
        auto* labels_data = labels.template data<LabelT>();
        if (ignore_index >= 0 && ignore_index < axis_dim) {
          SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(
              context.cuda_device_context(), rank, axis, logits_data,
              labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
              ignore_index);
        } else {
          SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(
              context.cuda_device_context(), rank, axis, logits_data,
              labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
              ignore_index);
        }
      }
    }
  }
};

/*
  CUDA backward kernel of softmax_with_cross_entropy.
  "Softmax" is first copied into the Logits gradient (unless they share
  storage). The kernels launched below then turn it, in place, into
  dLoss/dLogits, or into dLoss/dInput when use_softmax is false.
*/
template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
  }

  template <typename LabelT>
  static void Apply(const framework::ExecutionContext& context,
                    const framework::Tensor& labels, const bool soft_label) {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))
            ->template data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    // Seed the gradient buffer with the forward softmax; skipped when the
    // gradient is computed in place on the Softmax tensor.
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->template data<T>();

    const int rank = logit_grad->dims().size();
    const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int64_t n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
    const int64_t d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());
    const int64_t remain = d / axis_dim;

#ifdef __HIPCC__
    int block = 256;
#else
    int block = 512;
#endif
    auto stream = context.cuda_device_context().stream();
    auto ignore_index = context.Attr<int>("ignore_index");
    auto use_softmax = context.Attr<bool>("use_softmax");

    // No softmax op: the forward input was already a softmax.
    if (!use_softmax) {
      if (soft_label) {
        int grid = (n * d + block - 1) / block;
        const T* label_data = labels.template data<T>();
        SoftLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
            logit_grad_data, loss_grad_data, label_data, n, d, remain);
      } else {
        // Two passes: write -1/p at the labelled class, then scale those
        // entries by the loss gradient and zero everything else.
        int grid = (n * remain + block - 1) / block;
        const auto* label_data = labels.template data<LabelT>();
        HardLabelCrossEntropyGradientKernel<T,
                                            LabelT><<<grid, block, 0, stream>>>(
            logit_grad_data, label_data, n, d, remain, ignore_index);
        int num = n * d;
        grid = (num + block - 1) / block;
        ScaleCrossEntropyGradient<T, LabelT><<<grid, block, 0, stream>>>(
            logit_grad_data, loss_grad_data, num, d, remain, label_data,
            ignore_index);
      }

      return;
    }

    // With softmax: the gradient w.r.t. the logits.
    if (soft_label) {
      int64_t grid = (n * d + block - 1) / block;
      const T* label_data = labels.template data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
    } else {
      const auto* label_data = labels.template data<LabelT>();
      int grid = (n * d + block - 1) / block;
      SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
          ignore_index);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so only float and float16 are registered.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>);
#else
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);
#endif