/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
S
sneaxiy 已提交
18
#include "paddle/fluid/operators/math/cross_entropy.h"
19
#include "paddle/fluid/operators/math/math_function.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
21
#include "paddle/fluid/platform/for_range.h"
22

C
caoying03 已提交
23 24 25 26 27
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

namespace {

// Hard-label backward, part 1: subtracts 1 from the gradient entry of each
// sample's labelled class. logit_grad is indexed as [n, d] where
// d = axis_dim * remain (i.e. viewed as [n, axis_dim, remain]); labels is
// indexed as [n, remain]. Positions whose label equals ignore_index are
// skipped. Other labels are not range-checked here: they must lie in
// [0, axis_dim) or the write below leaves the sample's row.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int64_t n, const int64_t d,
                                 const int64_t remain, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) {
    int64_t idx_n = index / remain;
    int64_t idx_remain = index % remain;
    int64_t tmp = labels[index];
    if (ignore_index != tmp) {
      int64_t idx = idx_n * d + tmp * remain + idx_remain;
      logit_grad[idx] -= static_cast<T>(1.);
    }
  }
}

// Hard-label backward, part 2: visits `num` gradient entries (the caller
// passes n * d). Each entry is multiplied by the upstream loss gradient of
// its (n, remain) position, or set to 0 when that position's label equals
// ignore_index.
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num,
                      const int64_t d, const int64_t remain,
                      const int64_t* labels, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
    int64_t idx_n = index / d;
    int64_t idx_remain = index % remain;
    int64_t idx_lbl = idx_n * remain + idx_remain;
    if (labels[idx_lbl] == ignore_index) {
      logit_grad[index] = static_cast<T>(0.);
    } else {
      logit_grad[index] *= loss_grad[idx_lbl];
    }
  }
}

// Soft-label backward: for every element of the [n, axis_dim, remain] view,
//   logit_grad[i] = loss_grad[(n, remain) of i] * (logit_grad[i] - labels[i]).
// logit_grad is expected to hold Softmax on entry (the grad op copies it in).
// Uses the same CUDA_KERNEL_LOOP_TYPE element loop as the kernels above
// instead of a single-shot thread index, so every one of the n * d elements
// is processed even if the launch grid does not cover them all.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int64_t n,
                                               const int64_t d,
                                               const int64_t remain) {
  CUDA_KERNEL_LOOP_TYPE(ids, n * d, int64_t) {
    int64_t idx_n = ids / d;
    int64_t idx_remain = ids % remain;
    int64_t idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}

}  // namespace
// Device-side exp/log helpers, overloaded for each element type this file is
// instantiated with (float16, float, double). The log variants pass the raw
// logarithm through math::TolerableValue<T> before returning it.

// float16 overloads (via Eigen's numext functions).
static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  const auto result = ::Eigen::numext::exp(x);
  return result;
}
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 x) {
  const auto raw_log = ::Eigen::numext::log(x);
  return math::TolerableValue<platform::float16>()(raw_log);
}

// float overloads (single-precision math functions).
static __device__ __forceinline__ float exp_on_device(float x) {
  return expf(x);
}
static __device__ __forceinline__ float log_on_device(float x) {
  const float raw_log = logf(x);
  return math::TolerableValue<float>()(raw_log);
}

// double overloads.
static __device__ __forceinline__ double exp_on_device(double x) {
  return exp(x);
}
static __device__ __forceinline__ double log_on_device(double x) {
  const double raw_log = log(x);
  return math::TolerableValue<double>()(raw_log);
}

/** The following code implements three row-wise CUDA kernels that together
 * compute softmax and the cross-entropy loss. **/
/*
  Let x be `logits` and y be `labels`. The equations are as follows:
  cross\_entropy_i
    = \sum_{j}[-y_{i,j} * \log(e^{x_{i,j}} / \sum_{k} e^{x_{i,k}})]
    = \sum_{j}[-y_{i,j} * \log(e^{x_{i,j} - max_i} / \sum_{k} e^{x_{i,k} - max_i})]
    = \sum_{j}[-y_{i,j} * (x_{i,j} - max_i - \log\sum_{k} e^{x_{i,k} - max_i})]
    = \sum_{j}[-y_{i,j} * (x_{i,j} - max_i - logDiffMaxSum_i)]
    = \sum_{j}(-y_{i,j} * tmp_{i,j})
  softmax_{i,j} = e^{tmp_{i,j}}
where:
  max_i           = \max_{j} x_{i,j}
  logDiffMaxSum_i = \log\sum_{j} e^{x_{i,j} - max_i}
  tmp_{i,j}       = x_{i,j} - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise reduction to calculate max_i
Step 2: row-wise reduction to calculate logDiffMaxSum_i
Step 3: calculate tmp_{i,j}, and from it softmax_{i,j} and cross\_entropy_i
To save memory, max_i, logDiffMaxSum_i and cross\_entropy_i share one buffer.
With that, the 3 steps become:
Step 1 (RowReductionForMax): row-wise reduction to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
  softmax'_{i,j} = x_{i,j} - max_i, and a row-wise reduction to calculate
  logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate
  tmp_{i,j} = softmax'_{i,j} - logDiffMaxSum_i, and finally softmax_{i,j} and
  cross\_entropy_i
*/

// Block-wide reduction aliases used by the row-reduction kernels below
// (`cub` is an alias for hipcub when compiling with HIP).
// cub offers three block-reduce algorithms:
//   BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
//   BLOCK_REDUCE_RAKING
//   BLOCK_REDUCE_WARP_REDUCTIONS (the default)
// The alias relies on the default; uncomment the third template argument to
// pick an algorithm explicitly.
template <typename T, int BlockDim>
using BlockReduce =
    cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;

// Scratch storage needed by BlockReduce<T, BlockDim>; each kernel declares
// one instance of it as a __shared__ variable.
template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

// Step 1 of the fused softmax: max_data[blockIdx.x] = max_i, the maximum of
// one row of logits along the axis dimension.
// Precondition: BlockDim <= axis_dim, so every thread owns at least one
// element of its row.
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits_data is viewed as [n, axis_dim, remain] and max_data as
  // [n, 1, remain]; the grid has n * remain blocks, one per output value.
  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  const int64_t stride = BlockDim * remain;
  const int64_t row_end = (row + 1) * d;

  // Each thread scans every BlockDim-th element along the axis.
  int64_t pos = row * d + threadIdx.x * remain + col;
  T thread_max = logits_data[pos];
  for (pos += stride; pos < row_end; pos += stride) {
    const T candidate = logits_data[pos];
    if (thread_max < candidate) thread_max = candidate;
  }

  // The reduced value is only meaningful in thread 0.
  const T block_max =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) max_data[blockIdx.x] = block_max;
}

// Step 2 of the fused softmax. Precondition: BlockDim <= axis_dim.
// On entry max_data[blockIdx.x] holds max_i. The kernel stores
// softmax'_i_j = x_i_j - max_i into `softmax` and overwrites
// max_data[blockIdx.x] with logDiffMaxSum_i = log(sum_j exp(softmax'_i_j)).
// With CalculateLogSoftmax = true it additionally subtracts logDiffMaxSum_i
// from `softmax` (leaving log-softmax there) and finally resets
// max_data[blockIdx.x] to 0.
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax,
                                                 int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits / softmax are viewed as [n, axis_dim, remain] and max_data as
  // [n, 1, remain]; each of the n * remain blocks handles one (row, col).
  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  const int64_t first = row * d + threadIdx.x * remain + col;
  const int64_t row_end = (row + 1) * d;
  const int64_t stride = BlockDim * remain;

  auto block_max = max_data[blockIdx.x];

  // Numerically stable formulation: the loss is later derived from
  // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i instead of
  // log(exp(x_i_j - max_i) / DiffMaxSum_i), so log(0) cannot occur, and
  // softmax_i_j = e^{tmp_i_j} stays within [0, 1]. No clipping is needed on
  // the shifted softmax.
  int64_t pos = first;
  softmax[pos] = logits_data[pos] - block_max;
  T thread_sum = exp_on_device(softmax[pos]);
  for (pos += stride; pos < row_end; pos += stride) {
    softmax[pos] = logits_data[pos] - block_max;
    thread_sum += exp_on_device(softmax[pos]);
  }

  thread_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(thread_sum);

  // CalculateLogSoftmax is a compile-time constant, so the barriers below are
  // reached by either all threads of the block or none.
  if (CalculateLogSoftmax) {
    // Make thread 0's logDiffMaxSum_i visible to the whole block.
    __syncthreads();
    const T log_sum = max_data[blockIdx.x];
    pos = first;
    do {
      softmax[pos] -= log_sum;
      pos += stride;
    } while (pos < row_end);

    // Every thread has read max_data[blockIdx.x]; only now may thread 0
    // clear it (presumably so loss entries not written later read as 0).
    __syncthreads();
    if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
  }
}

#ifdef __HIPCC__  // @{ HIP separate kernels for RowReductionForDiffMaxSum
// Note(qili93): HIP does not support an early `return` in this kernel, so
// RowReductionForDiffMaxSum is split into the two kernels below.

// HIP counterpart of the first half of RowReductionForDiffMaxSum: writes
// softmax'_i_j = x_i_j - max_i into `softmax` and replaces
// max_data[blockIdx.x] (holding max_i on entry) with logDiffMaxSum_i.
// Precondition: BlockDim <= axis_dim.
template <typename T, int BlockDim>
static __global__ void RowReductionForSum(const T* logits_data, T* max_data,
                                          T* softmax, int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits / softmax viewed as [n, axis_dim, remain]; max_data as
  // [n, 1, remain]; one block per (idx_n, idx_remain) pair.
  int64_t remain = d / axis_dim;
  int64_t idx_n = blockIdx.x / remain;
  int64_t idx_remain = blockIdx.x % remain;
  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int64_t end_idx = (idx_n + 1) * d;

  auto block_max = max_data[blockIdx.x];
  int64_t step = BlockDim * remain;

  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  auto idx = beg_idx + step;
  while (idx < end_idx) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
    idx += step;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);
}

// HIP counterpart of the log-softmax tail of RowReductionForDiffMaxSum:
// subtracts logDiffMaxSum_i (in max_data[blockIdx.x]) from `softmax`, then
// resets max_data[blockIdx.x] to 0.
// Index arithmetic uses int64_t like every other kernel in this file. The
// previous `int` version narrowed the int64_t `d` passed by its callers and
// overflowed once the flat offsets exceeded INT_MAX.
// CalculateLogSoftmax is unused; it is kept so existing instantiations still
// compile.
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiff(const T* logits_data, T* max_data,
                                           T* softmax, int64_t d,
                                           int axis_dim) {
  int64_t remain = d / axis_dim;
  int64_t idx_n = blockIdx.x / remain;
  int64_t idx_remain = blockIdx.x % remain;
  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int64_t end_idx = (idx_n + 1) * d;
  int64_t step = BlockDim * remain;

  T diff_max_sum = max_data[blockIdx.x];
  softmax[beg_idx] -= diff_max_sum;
  beg_idx += step;
  while (beg_idx < end_idx) {
    softmax[beg_idx] -= diff_max_sum;
    beg_idx += step;
  }

  // All threads must have read max_data[blockIdx.x] before thread 0 clears it.
  __syncthreads();
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}
#endif  // @} End HIP separate kernels for RowReductionForDiffMaxSum

// Step 3 of the fused soft-label softmax + cross-entropy.
// Precondition: BlockDim <= axis_dim.
// On entry `softmax` holds softmax'_i_j = x_i_j - max_i and
// loss_data[blockIdx.x] holds logDiffMaxSum_i (the loss buffer is reused as
// scratch). The kernel writes softmax_i_j = exp(tmp_i_j) with
// tmp_i_j = softmax'_i_j - logDiffMaxSum_i, and stores
// cross_entropy_i = sum_j(-label_i_j * tmp_i_j) into loss_data[blockIdx.x].
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
    int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits / softmax / labels viewed as [n, axis_dim, remain]; loss_data as
  // [n, 1, remain]; one block per (row, col) pair.
  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  const int64_t row_end = (row + 1) * d;
  const int64_t stride = BlockDim * remain;
  int64_t pos = row * d + threadIdx.x * remain + col;

  const auto log_sum = loss_data[blockIdx.x];

  auto log_prob = softmax[pos] - log_sum;
  softmax[pos] = exp_on_device(log_prob);
  auto thread_loss = -labels_data[pos] * log_prob;
  for (pos += stride; pos < row_end; pos += stride) {
    log_prob = softmax[pos] - log_sum;
    softmax[pos] = exp_on_device(log_prob);
    thread_loss -= (labels_data[pos] * log_prob);
  }

  // The reduction contains a block barrier, so every thread has already read
  // log_sum before thread 0 overwrites the slot with the loss.
  thread_loss =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = thread_loss;
}

// Final per-element step of the hard-label forward pass. It is selected by
// HardLabelSoftmaxWithCrossEntropy when ignore_idx is outside [0, axis_dim).
// It is applied to every element of the [n, axis_dim, remain] view of
// log_softmax. It converts the stored log-softmax value back to softmax in
// place, and for the element at the sample's label it writes
// loss = -log_softmax.
// Labels must lie in [0, axis_dim) or equal ignore_idx; anything else trips
// the device-side PADDLE_ENFORCE.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int64_t d,
                                          int axis_dim, int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int64_t idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int64_t remain = d_ / axis_dim_;
    int64_t idx_n = idx / d_;
    int64_t idx_axis = (idx % d_) / remain;
    int64_t idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int64_t idx_lbl = idx_n * remain + idx_remain;
    const int64_t label = labels_[idx_lbl];
    // A label is a class id, so the valid range is [0, axis_dim_). Bounding
    // by d_ (= axis_dim_ * remain) would let labels in [axis_dim_, d_) through
    // when remain > 1, silently producing a zero loss.
    PADDLE_ENFORCE((label >= 0 && label < axis_dim_) || label == ignore_idx_,
                   "The value of label[%ld] expected >= 0 and < %ld, or == %d, "
                   "but got %ld. Please check input value.",
                   idx_lbl, static_cast<int64_t>(axis_dim_), ignore_idx_,
                   label);
    // Every element is turned from log-softmax back into softmax. Only the
    // labelled class writes the loss. Because ignore_idx_ is outside
    // [0, axis_dim_) here, an ignored label never matches idx_axis.
    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    if (idx_axis == label) {
      loss_[idx_lbl] = -log_prob;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int64_t d_;
  int axis_dim_;
  int ignore_idx_;
};

// Final per-element step of the hard-label forward pass. It is selected by
// HardLabelSoftmaxWithCrossEntropy when 0 <= ignore_idx < axis_dim, i.e. when
// ignore_idx is itself a class id. It turns the stored log-softmax back into
// softmax in place, and writes loss = -log_softmax for the labelled class
// unless that class is ignore_idx.
// NOTE(review): unlike HardLabelSoftmaxWithCrossEntropyFunctor, labels are
// not range-checked here; an out-of-range label simply leaves its loss entry
// unwritten.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int64_t d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int64_t idx) const {
    // Decompose idx over the [n, axis_dim, remain] view (d = axis_dim * remain).
    const int64_t remain = d_ / axis_dim_;
    const int64_t sample = idx / d_;
    const int64_t cls = (idx % d_) / remain;
    const int64_t inner = idx % remain;
    // labels and loss are laid out as [n, remain].
    const int64_t lbl_pos = sample * remain + inner;
    const int64_t label = labels_[lbl_pos];

    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    const bool is_target = (cls == label) && (cls != ignore_idx_);
    if (is_target) {
      loss_[lbl_pos] = -log_prob;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int64_t d_;
  int axis_dim_;
  int ignore_idx_;
};

// Hard-label forward pass (numeric_stable_mode). It runs in three stages:
//   1. RowReductionForMax.
//   2. The log-softmax variant of the DiffMaxSum step (split into
//      RowReductionForSum + RowReductionForDiff under HIP).
//   3. A per-element functor that writes softmax and the loss.
// loss_data doubles as scratch space for the per-row reductions.
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
    int64_t d, int axis_dim, int ignore_idx) {
  constexpr int kMaxBlockDim = 512;
  // Largest power of two not exceeding axis_dim, capped at kMaxBlockDim, so
  // that BlockDim <= axis_dim as the row-reduction kernels require.
  const int64_t block_dim =
      axis_dim < kMaxBlockDim ? (1 << static_cast<int>(std::log2(axis_dim)))
                              : kMaxBlockDim;
  // One block per (n, remain) position of the [n, axis_dim, remain] view.
  const int64_t grid_dim = n * d / axis_dim;
  auto stream = ctx.stream();

  // Applies the per-element functor over all n * d elements. The
  // ignore-index-aware variant is used only when ignore_idx is a class id.
  auto apply_label_functor = [&]() {
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n * d);
    if (ignore_idx >= 0 && ignore_idx < axis_dim) {
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));
    } else {
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));
    }
  };

#ifdef __HIPCC__
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)      \
  case BlockDim:                                                               \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>),       \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
                       loss_data, d, axis_dim);                                \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>),       \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
                       loss_data, softmax_data, d, axis_dim);                  \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForDiff<T, BlockDim>),      \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
                       loss_data, softmax_data, d, axis_dim);                  \
    apply_label_functor();                                                     \
    break
#else
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
  case BlockDim:                                                          \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(   \
        logits_data, loss_data, d, axis_dim);                             \
    RowReductionForDiffMaxSum<T, BlockDim,                                \
                              true><<<grid_dim, BlockDim, 0, stream>>>(   \
        logits_data, loss_data, softmax_data, d, axis_dim);               \
    apply_label_functor();                                                \
    break
#endif

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

// Enqueues the three soft-label kernels on `stream` for one compile-time
// block size:
//   1. RowReductionForMax.
//   2. The DiffMaxSum step (RowReductionForSum under HIP).
//   3. RowReductionForSoftmaxAndCrossEntropy.
// loss_data holds max_i and then logDiffMaxSum_i between the launches before
// it receives the final loss.
template <typename T, int BlockDim>
static void LaunchSoftLabelKernels(const T* logits_data, const T* labels_data,
                                   T* softmax_data, T* loss_data, int64_t d,
                                   int axis_dim, int64_t grid_dim,
                                   gpuStream_t stream) {
#ifdef __HIPCC__
  hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>),
                     dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data,
                     loss_data, d, axis_dim);
  hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>),
                     dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data,
                     loss_data, softmax_data, d, axis_dim);
  hipLaunchKernelGGL(
      HIP_KERNEL_NAME(RowReductionForSoftmaxAndCrossEntropy<T, BlockDim>),
      dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, labels_data,
      loss_data, softmax_data, d, axis_dim);
#else
  RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(
      logits_data, loss_data, d, axis_dim);
  RowReductionForDiffMaxSum<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(
      logits_data, loss_data, softmax_data, d, axis_dim);
  RowReductionForSoftmaxAndCrossEntropy<
      T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(
      logits_data, labels_data, loss_data, softmax_data, d, axis_dim);
#endif
}

// Soft-label forward pass: computes softmax and the soft-label cross-entropy
// of logits viewed as [n, axis_dim, remain] (d = axis_dim * remain). The
// block size is the largest power of two <= axis_dim, capped at 512. Block
// sizes that are not a power of two >= 2 raise Unavailable.
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(
    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
    int64_t n, int64_t d, int axis_dim, gpuStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  const int64_t block_dim =
      axis_dim < kMaxBlockDim ? (1 << static_cast<int>(std::log2(axis_dim)))
                              : kMaxBlockDim;
  // One block per (n, remain) position.
  const int64_t grid_dim = n * d / axis_dim;

#define LAUNCH_SOFT_LABEL_KERNELS_CASE(BlockDim)                              \
  case BlockDim:                                                              \
    LaunchSoftLabelKernels<T, BlockDim>(logits_data, labels_data,             \
                                        softmax_data, loss_data, d, axis_dim, \
                                        grid_dim, stream);                    \
    break

  switch (block_dim) {
    LAUNCH_SOFT_LABEL_KERNELS_CASE(512);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(256);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(128);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(64);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(32);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(16);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(8);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(4);
    LAUNCH_SOFT_LABEL_KERNELS_CASE(2);
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }

#undef LAUNCH_SOFT_LABEL_KERNELS_CASE
}

template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward pass: computes Softmax and Loss of Logits/Label along Attr(axis).
  // It dispatches to one of three implementations:
  //   - the fused soft-label kernels,
  //   - the numerically stable hard-label kernels, or
  //   - cuDNN softmax followed by CrossEntropyFunctor.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const auto& dev_ctx = context.cuda_device_context();
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const auto& logits_dims = logits->dims();
    const int axis =
        CanonicalAxis(context.Attr<int>("axis"), logits_dims.size());
    const int axis_dim = logits_dims[axis];
    // Logits are treated as [n, d] with d = axis_dim * remain.
    const int64_t n = SizeToAxis(axis, logits_dims);
    const int64_t d = SizeFromAxis(axis, logits_dims);

    T* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    T* loss_data = loss->mutable_data<T>(context.GetPlace());

    // Only one class: softmax is identically 1 and the loss is 0.
    if (axis_dim == 1) {
      math::SetConstant<platform::CUDADeviceContext, T> fill;
      fill(dev_ctx, softmax, static_cast<T>(1));
      fill(dev_ctx, loss, static_cast<T>(0));
      return;
    }

    const bool soft_label = context.Attr<bool>("soft_label");
    const int ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      const T* logits_data = logits->data<T>();
      const T* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(logits_data, labels_data,
                                         softmax_data, loss_data, n, d,
                                         axis_dim, dev_ctx.stream());
      return;
    }

    if (context.Attr<bool>("numeric_stable_mode")) {
      const T* logits_data = logits->data<T>();
      const int64_t* labels_data = labels->data<int64_t>();
      HardLabelSoftmaxWithCrossEntropy<T>(dev_ctx, logits_data, labels_data,
                                          loss_data, softmax_data, n, d,
                                          axis_dim, ignore_index);
      return;
    }

    // cuDNN path: the cuDNN softmax only supports 2-D input with softmax on
    // the last dimension, so every tensor is viewed as 2-D with n rows.
    // NOTE(review): softmax over d equals softmax over the class axis only
    // when axis is the last dimension; presumably enforced by the op's
    // attribute checks when numeric_stable_mode is off — confirm.
    Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
    logits_2d.ShareDataWith(*logits).Resize({n, d});
    softmax_2d.ShareDataWith(*softmax).Resize({n, d});
    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
    loss_2d.ShareDataWith(*loss).Resize({n, 1});
    math::SoftmaxCUDNNFunctor<T>()(dev_ctx, &logits_2d, &softmax_2d);
    math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
        dev_ctx, &loss_2d, &softmax_2d, &labels_2d, false, ignore_index,
        axis_dim);
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward pass: computes Logits@GRAD from Softmax, Label and Loss@GRAD.
  // The gradient buffer is seeded with Softmax. It is then updated in place
  // by SoftCrossEntropyGradientKernel (soft labels) or by
  // CrossEntropyGrad + Scale (hard labels).
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");

    // Seed the gradient with Softmax; skipped when Logits@GRAD already
    // aliases Softmax (in-place execution).
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->data<T>();

    const auto& grad_dims = logit_grad->dims();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), grad_dims.size());
    const int axis_dim = grad_dims[axis];
    // The gradient is viewed as [n, axis_dim, remain] with d = axis_dim * remain.
    const int64_t n = SizeToAxis(axis, grad_dims);
    const int64_t d = SizeFromAxis(axis, grad_dims);
    const int64_t remain = d / axis_dim;

    const int block = 512;
    auto stream = context.cuda_device_context().stream();
    const int ignore_index = context.Attr<int>("ignore_index");
    const int64_t total = n * d;

    if (context.Attr<bool>("soft_label")) {
      const int64_t grid = (total + block - 1) / block;
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
      return;
    }

    // Hard labels: subtract the one-hot label (one thread per (n, remain)
    // position), then scale every element by Loss@GRAD.
    const int64_t label_grid = (n * remain + block - 1) / block;
    const int64_t* label_data = labels->data<int64_t>();
    CrossEntropyGrad<T><<<label_grid, block, 0, stream>>>(
        logit_grad_data, label_data, n, d, remain, ignore_index);
    const int64_t scale_grid = (total + block - 1) / block;
    Scale<T><<<scale_grid, block, 0, stream>>>(logit_grad_data,
                                               loss_grad_data, total, d,
                                               remain, label_data,
                                               ignore_index);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Kernel registrations. ROCm builds (PADDLE_WITH_HIP) register float and
// float16 only, because MIOpen does not support double; CUDA builds also
// register double.
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy,
    ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>);
#else   // CUDA
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy,
    ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);
#endif  // PADDLE_WITH_HIP