/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
S
sneaxiy 已提交
18
#include "paddle/fluid/operators/math/cross_entropy.h"
19
#include "paddle/fluid/operators/math/math_function.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
21
#include "paddle/fluid/platform/for_range.h"
22

C
caoying03 已提交
23 24 25 26 27
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

28
namespace {
// Hard-label gradient step: for every (row, inner) position whose label is
// not `ignore_index`, subtracts 1 from the gradient entry of the labelled
// class, located at row * d + label * remain + inner.
// `labels` holds n * remain entries.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int64_t n, const int64_t d,
                                 const int64_t remain, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(pos, n * remain, int64_t) {
    const int64_t label = labels[pos];
    if (label != ignore_index) {
      const int64_t row = pos / remain;
      const int64_t inner = pos % remain;
      logit_grad[row * d + label * remain + inner] -= static_cast<T>(1.);
    }
  }
}
Y
Yu Yang 已提交
43

44
template <typename T>
45 46 47 48 49 50 51
__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num,
                      const int64_t d, const int64_t remain,
                      const int64_t* labels, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
    int64_t idx_n = index / d;
    int64_t idx_remain = index % remain;
    int64_t idx_lbl = idx_n * remain + idx_remain;
52 53 54 55 56
    if (labels[idx_lbl] == ignore_index) {
      logit_grad[index] = static_cast<T>(0.);
    } else {
      logit_grad[index] *= loss_grad[idx_lbl];
    }
57 58 59 60 61 62
  }
}

// Soft-label gradient, computed in place:
//   logit_grad[i] = loss_grad[pos(i)] * (logit_grad[i] - labels[i])
// where pos(i) = (i / d) * remain + i % remain.
// NOTE(review): logit_grad presumably holds the softmax output on entry —
// confirm with the caller.
// One thread per element: the launch must provide at least n * d threads;
// surplus threads do nothing.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int64_t n,
                                               const int64_t d,
                                               const int64_t remain) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise computed in
  // 32-bit unsigned arithmetic and wraps for grids covering > 2^32 threads.
  const int64_t ids =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    const int64_t idx_n = ids / d;
    const int64_t idx_remain = ids % remain;
    const int64_t idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}
S
sneaxiy 已提交
74

75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
template <typename T>
__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
                                                    const T* loss_grad,
                                                    const T* labels,
                                                    const int n, const int d,
                                                    const int remain) {
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int idx_n = ids / d;
    int idx_remain = ids % remain;
    int idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]);
  }
}

// Hard-label cross-entropy gradient w.r.t. the probability of the labelled
// class: for every (row, inner) position whose label is not `ignore_index`,
// replaces the entry at row * d + label * remain + inner with -1 / entry.
// The label is kept as int64_t (no narrowing before the ignore_index test),
// and all offsets are 64-bit because row * d can exceed INT_MAX even when
// n * remain does not.
template <typename T>
__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
                                                    const int64_t* labels,
                                                    const int n, const int d,
                                                    const int remain,
                                                    const int ignore_index) {
  const int64_t num = static_cast<int64_t>(n) * remain;
  CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
    const int64_t label = labels[index];
    if (label != ignore_index) {
      const int64_t idx_n = index / remain;
      const int64_t idx_remain = index % remain;
      const int64_t idx = idx_n * d + label * remain + idx_remain;
      logit_grad[idx] = -static_cast<T>(1.) / logit_grad[idx];
    }
  }
}

// Keeps only the gradient entry of the labelled class, scaled by the
// upstream loss gradient; every other entry — and every entry of a position
// labelled `ignore_index` — is zeroed. For flat index `index`, the class is
// (index % d) / remain and the label position is
// (index / d) * remain + index % remain.
template <typename T>
__global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
                                          const int num, const int d,
                                          const int remain,
                                          const int64_t* labels,
                                          const int ignore_index) {
  CUDA_KERNEL_LOOP(index, num) {
    const int lbl_pos = (index / d) * remain + index % remain;
    const int class_id = (index % d) / remain;
    const int64_t label = labels[lbl_pos];
    const bool keep = label != ignore_index && label == class_id;
    if (keep) {
      logit_grad[index] *= loss_grad[lbl_pos];
    } else {
      logit_grad[index] = static_cast<T>(0.);
    }
  }
}

}  // namespace
C
caoying03 已提交
127

128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float exp_on_device(float x) {
  return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
  return exp(x);
}
// Device-side natural logarithm per element type. Every result is passed
// through math::TolerableValue<T> before being returned.
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 x) {
  const auto raw = ::Eigen::numext::log(x);
  return math::TolerableValue<platform::float16>()(raw);
}
static __device__ __forceinline__ float log_on_device(float x) {
  const auto raw = logf(x);
  return math::TolerableValue<float>()(raw);
}
static __device__ __forceinline__ double log_on_device(double x) {
  const auto raw = log(x);
  return math::TolerableValue<double>()(raw);
}

/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss **/
/*
  Suppose x is `logits` and y is `labels`; the equations are as follows:
  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{j}{x_i_j}
  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, max_i, logDiffMaxSum_i and cross\_entropy_i share the same
buffer. In this way, the 3 steps become:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
softmax'_i_j = x_i_j - max_i, and row-wise operation to calculate
logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate
tmp_i_j = softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
cross\_entropy_i
*/

// cub provides three block-reduction algorithms:
//   BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
//   BLOCK_REDUCE_RAKING
//   BLOCK_REDUCE_WARP_REDUCTIONS (the library default)
// The alias below uses cub's default; add the algorithm as a third template
// argument of cub::BlockReduce to try one of the others.
template <typename T, int BlockDim>
using BlockReduce = cub::BlockReduce<T, BlockDim>;

// Shared-memory scratch type that BlockReduce<T, BlockDim> requires.
template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

189
// Make sure that BlockDim <= axis_dim
S
sneaxiy 已提交
190 191
// This kernel is used to calculate the max element of each row
template <typename T, int BlockDim>
S
sneaxiy 已提交
192
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
193
                                          int64_t d, int axis_dim) {
S
sneaxiy 已提交
194 195
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

196 197 198
  // logits_data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // blockDim = n * remain, split blockIdx to idx_n and idx_remain
199 200 201 202 203
  int64_t remain = d / axis_dim;
  int64_t idx_n = blockIdx.x / remain;
  int64_t idx_remain = blockIdx.x % remain;
  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int64_t end_idx = (idx_n + 1) * d;
S
sneaxiy 已提交
204

205
  int64_t step = BlockDim * remain;
S
sneaxiy 已提交
206
  T cur_max = logits_data[beg_idx];
207
  beg_idx += step;
S
sneaxiy 已提交
208 209 210 211
  while (beg_idx < end_idx) {
    if (cur_max < logits_data[beg_idx]) {
      cur_max = logits_data[beg_idx];
    }
212
    beg_idx += step;
S
sneaxiy 已提交
213 214 215 216
  }

  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());

217
  if (threadIdx.x == 0) max_data[blockIdx.x] = cur_max;
S
sneaxiy 已提交
218 219
}

// Step 2: on entry max_data[blockIdx.x] holds the row maximum.
//  * Writes softmax[i] = logits_data[i] - max for every element of the row.
//  * Replaces max_data[blockIdx.x] with log(sum(exp(logits - max))).
//  * If CalculateLogSoftmax is true, it also subtracts that log-sum from
//    softmax (yielding log-softmax) and then resets max_data[blockIdx.x] to 0.
// Layout: logits/softmax viewed as [n, axis_dim, remain] and max_data as
// [n, 1, remain], with one block per (n, remain) position.
// Requires BlockDim <= axis_dim.
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax,
                                                 int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  const int64_t first = row * d + threadIdx.x * remain + col;
  const int64_t row_end = (row + 1) * d;
  const int64_t stride = BlockDim * remain;

  const T row_max = max_data[blockIdx.x];

  // Numerically stable form: the loss is later computed from
  // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i rather than
  // log(exp(x_i_j - max_i) / DiffMaxSum_i), so log(0) cannot occur, and
  // softmax_i_j = e^{tmp_i_j} stays within [0, 1]. No clipping is needed.
  int64_t pos = first;
  T shifted = logits_data[pos] - row_max;
  softmax[pos] = shifted;
  T partial_sum = exp_on_device(shifted);
  for (pos += stride; pos < row_end; pos += stride) {
    shifted = logits_data[pos] - row_max;
    softmax[pos] = shifted;
    partial_sum += exp_on_device(shifted);
  }

  const T block_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(partial_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(block_sum);

  // CalculateLogSoftmax is a template constant, so every thread of the block
  // takes the same branch and the barriers below are reached uniformly.
  if (CalculateLogSoftmax) {
    // Make thread 0's log-sum visible to the whole block.
    __syncthreads();
    const T log_sum = max_data[blockIdx.x];
    int64_t p = first;
    do {
      softmax[p] -= log_sum;
      p += stride;
    } while (p < row_end);

    // Note(zhiqiu): every thread must have read max_data[blockIdx.x] before
    // thread 0 overwrites it, hence this second barrier.
    __syncthreads();
    if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
  }
}

#ifdef __HIPCC__  // @{ HIP Separate Kernel for RowReductionForDiffMaxSum
// Note(qili93): HIP does not support return in kernel, so
// RowReductionForDiffMaxSum is split into the two kernels below.

// First half of RowReductionForDiffMaxSum. On entry max_data[blockIdx.x]
// holds the row max. The kernel writes softmax[i] = logits_data[i] - max and
// stores log(sum(exp(logits - max))) into max_data[blockIdx.x].
// One block per (n, remain) position; requires BlockDim <= axis_dim.
template <typename T, int BlockDim>
static __global__ void RowReductionForSum(const T* logits_data, T* max_data,
                                          T* softmax, int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  int64_t remain = d / axis_dim;
  int64_t idx_n = blockIdx.x / remain;
  int64_t idx_remain = blockIdx.x % remain;
  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int64_t end_idx = (idx_n + 1) * d;

  auto block_max = max_data[blockIdx.x];
  int64_t step = BlockDim * remain;

  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  auto idx = beg_idx + step;
  while (idx < end_idx) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
    idx += step;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);
}

// Second half of RowReductionForDiffMaxSum. It subtracts the log-sum held in
// max_data[blockIdx.x] from every softmax element of the row (yielding
// log-softmax), then resets max_data[blockIdx.x] to 0.
// Indices are 64-bit, matching RowReductionForSum and the callers' int64_t
// `d`; 32-bit offsets (idx_n * d) overflow on large inputs.
// CalculateLogSoftmax is unused; it is kept so existing instantiations still
// compile.
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiff(const T* logits_data, T* max_data,
                                           T* softmax, int64_t d,
                                           int axis_dim) {
  const int64_t remain = d / axis_dim;
  const int64_t idx_n = blockIdx.x / remain;
  const int64_t idx_remain = blockIdx.x % remain;
  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  const int64_t end_idx = (idx_n + 1) * d;
  const int64_t step = BlockDim * remain;

  T diff_max_sum = max_data[blockIdx.x];
  softmax[beg_idx] -= diff_max_sum;
  beg_idx += step;
  while (beg_idx < end_idx) {
    softmax[beg_idx] -= diff_max_sum;
    beg_idx += step;
  }

  // Every thread must have read max_data[blockIdx.x] before thread 0
  // overwrites it.
  __syncthreads();
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}
#endif  // @} End HIP Separate Kernel for RowReductionForDiffMaxSum

// Step 3 of the fused softmax + cross-entropy.
// On entry:
//   * softmax holds x - max (from step 2);
//   * loss_data[blockIdx.x] holds log(sum(exp(x - max))).
// For every element of the row it computes tmp = softmax - log_sum and
// stores softmax = exp(tmp). It then overwrites loss_data[blockIdx.x] with
// -sum(labels * tmp).
// `logits_data` is not read by this kernel.
// One block per (n, remain) position; requires BlockDim <= axis_dim.
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
    int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // softmax / labels viewed as [n, axis_dim, remain]; loss as [n, 1, remain].
  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  const int64_t row_end = (row + 1) * d;
  const int64_t stride = BlockDim * remain;

  // loss_data doubles as storage for the per-row log-sum-exp.
  const T log_sum = loss_data[blockIdx.x];

  int64_t pos = row * d + threadIdx.x * remain + col;
  T tmp = softmax[pos] - log_sum;
  softmax[pos] = exp_on_device(tmp);
  T thread_loss = -labels_data[pos] * tmp;
  for (pos += stride; pos < row_end; pos += stride) {
    tmp = softmax[pos] - log_sum;
    softmax[pos] = exp_on_device(tmp);
    thread_loss -= (labels_data[pos] * tmp);
  }

  const T block_loss =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = block_loss;
}

// Cross-entropy for inputs that already hold probabilities (not fused with
// softmax):
//   loss_data[blockIdx.x] = -sum_j labels[j] * log(logits[j])
// summed over the axis_dim elements of one (n, remain) position.
// Layout: logits/labels viewed as [n, axis_dim, remain]; loss as
// [n, 1, remain].
// One block per (n, remain) position; requires BlockDim <= axis_dim.
// `d` and all offsets are 64-bit so that idx_n * d cannot overflow a 32-bit
// int; callers passing int still convert implicitly.
template <typename T, int BlockDim>
static __global__ void RowReductionForCrossEntropy(const T* logits_data,
                                                   const T* labels_data,
                                                   T* loss_data, int64_t d,
                                                   int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int64_t remain = d / axis_dim;
  const int64_t idx_n = blockIdx.x / remain;
  const int64_t idx_remain = blockIdx.x % remain;
  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  const int64_t end_idx = (idx_n + 1) * d;
  const int64_t step = BlockDim * remain;

  // When not fused with softmax, the probabilities are stored in logits_data.
  auto tmp = log_on_device(logits_data[beg_idx]);
  auto loss = -labels_data[beg_idx] * tmp;
  beg_idx += step;
  while (beg_idx < end_idx) {
    tmp = log_on_device(logits_data[beg_idx]);
    loss -= (labels_data[beg_idx] * tmp);
    beg_idx += step;
  }

  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}

// Element-wise functor over the [n, axis_dim, remain] probability tensor
// (d = axis_dim * remain). For the element lying on the labelled class of its
// (n, remain) position, it writes loss[n, remain] = -log(probability).
// Labels outside [0, axis_dim) never match a class index, so their loss
// entry is left untouched.
// Indices and `d_` are 64-bit so element counts beyond INT_MAX are not
// truncated; the int64_t parameters still accept int arguments.
template <typename T>
struct HardLabelCrossEntropyFunctor {
 public:
  HardLabelCrossEntropyFunctor(const int64_t* labels, T* loss,
                               const T* logits_data, int64_t d, int axis_dim)
      : labels_(labels),
        loss_(loss),
        logits_data_(logits_data),
        d_(d),
        axis_dim_(axis_dim) {}

  __device__ void operator()(int64_t idx) const {
    const int64_t remain = d_ / axis_dim_;
    const int64_t idx_n = idx / d_;
    const int64_t idx_axis = (idx % d_) / remain;
    const int64_t idx_remain = idx % remain;
    // labels and loss are viewed as [n, remain].
    const int64_t idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis == labels_[idx_lbl]) {
      loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  const T* logits_data_;
  int64_t d_;
  int axis_dim_;
};

// Same as the plain hard-label cross-entropy functor, except that positions
// labelled `ignore_idx` are skipped. It writes loss[n, remain] =
// -log(probability) only for the element on the labelled class, and only
// when that class is not ignore_idx.
// Indices and `d_` are 64-bit to avoid truncation on large tensors; the
// int64_t parameters still accept int arguments.
template <typename T>
struct HardLabelCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss,
                                            const T* logits_data, int64_t d,
                                            int axis_dim, int ignore_idx)
      : labels_(labels),
        loss_(loss),
        logits_data_(logits_data),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int64_t idx) const {
    // logits viewed as [n, axis_dim, remain], where d = axis_dim * remain.
    const int64_t remain = d_ / axis_dim_;
    const int64_t idx_n = idx / d_;
    const int64_t idx_axis = (idx % d_) / remain;
    const int64_t idx_remain = idx % remain;
    // labels and loss are viewed as [n, remain].
    const int64_t idx_lbl = idx_n * remain + idx_remain;

    if (idx_axis == labels_[idx_lbl] && idx_axis != ignore_idx_) {
      loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  const T* logits_data_;
  int64_t d_;
  int axis_dim_;
  int ignore_idx_;
};

// Hard-label cross-entropy on probabilities (no softmax). It runs one of the
// element-wise hard-label functors over all n * d elements:
//   * the ignore-index variant when ignore_idx is a valid class index;
//   * the plain variant otherwise.
// The computation is purely element-wise, so it does not depend on a
// thread-block size. axis_dim < 2 is rejected with the same error that the
// block-size dispatch in the other fused paths raises.
template <typename T>
static void HardLabelCrossEntropy(const platform::CUDADeviceContext& ctx,
                                  const T* logits_data,
                                  const int64_t* labels_data, T* loss_data,
                                  int n, int d, int axis_dim, int ignore_idx) {
  if (axis_dim < 2) {
    PADDLE_THROW(platform::errors::Unavailable(
        "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
  }

  // Widen before multiplying so that n * d cannot overflow a 32-bit int.
  platform::ForRange<platform::CUDADeviceContext> for_range(
      ctx, static_cast<int64_t>(n) * d);
  if (ignore_idx >= 0 && ignore_idx < axis_dim) {
    for_range(HardLabelCrossEntropyFunctorWithIgnoreIdx<T>(
        labels_data, loss_data, logits_data, d, axis_dim, ignore_idx));
  } else {
    for_range(HardLabelCrossEntropyFunctor<T>(labels_data, loss_data,
                                              logits_data, d, axis_dim));
  }
}

// Element-wise functor over the [n, axis_dim, remain] tensor
// (d = axis_dim * remain).
//   * Exponentiates log_softmax_ in place, turning log-softmax into softmax.
//   * For the element on the labelled class, sets
//     loss[n, remain] = -log_softmax.
// A label must be a class index in [0, axis_dim) or equal to ignore_idx_.
// Anything else trips the device-side PADDLE_ENFORCE.
// Note: the range bound is axis_dim (number of classes), not d, which also
// includes the `remain` factor.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int64_t d,
                                          int axis_dim, int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int64_t idx) const {
    // logits viewed as [n, axis_dim, remain], where d = axis_dim * remain.
    const int64_t remain = d_ / axis_dim_;
    const int64_t idx_n = idx / d_;
    const int64_t idx_axis = (idx % d_) / remain;
    const int64_t idx_remain = idx % remain;
    // labels and loss are viewed as [n, remain].
    const int64_t idx_lbl = idx_n * remain + idx_remain;
    const int64_t label = labels_[idx_lbl];
    PADDLE_ENFORCE((label >= 0 && label < axis_dim_) || label == ignore_idx_,
                   "The value of label[%ld] expected >= 0 and < %d, or == %d, "
                   "but got %ld. Please check input value.",
                   idx_lbl, axis_dim_, ignore_idx_, label);
    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    // An ignore_idx_ outside [0, axis_dim) never matches a class index, so
    // those positions keep their existing loss value.
    if (idx_axis == label) {
      loss_[idx_lbl] = -log_prob;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int64_t d_;
  int axis_dim_;
  int ignore_idx_;
};

// Element-wise functor over the [n, axis_dim, remain] tensor
// (d = axis_dim * remain).
//   * Exponentiates log_softmax_ in place.
//   * Sets loss[n, remain] = -log_softmax only for the element on the
//     labelled class, and only when that class is not ignore_idx_.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int64_t d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int64_t idx) const {
    const int64_t inner = d_ / axis_dim_;  // size of the trailing "remain" axis
    const int64_t row = idx / d_;
    const int64_t class_id = (idx % d_) / inner;
    const int64_t lbl_pos = row * inner + idx % inner;

    const bool is_target =
        class_id == labels_[lbl_pos] && class_id != ignore_idx_;
    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    if (is_target) loss_[lbl_pos] = -log_prob;
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int64_t d_;
  int axis_dim_;
  int ignore_idx_;
};

// Fused softmax + hard-label cross-entropy.
//  1. Row reductions, sized by a power-of-two block dimension, leave
//     log-softmax in softmax_data and reset loss_data.
//  2. An element-wise pass converts softmax_data to softmax and fills the
//     loss for the labelled class.
// Unsupported block dimensions (axis_dim < 2) throw before step 2 runs.
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
    int64_t d, int axis_dim, int ignore_idx) {
#ifdef __HIPCC__
  // HIP platform will have loss nan if dim size > 256
  constexpr int kMaxBlockDim = 256;
#else
  constexpr int kMaxBlockDim = 512;
#endif
  // Largest power of two not above axis_dim, capped at kMaxBlockDim.
  const int64_t block_dim =
      axis_dim >= kMaxBlockDim ? kMaxBlockDim
                               : (1 << static_cast<int>(std::log2(axis_dim)));
  // One block per (n, remain) position.
  const int64_t grid_dim = n * d / axis_dim;
  auto stream = ctx.stream();

#ifdef __HIPCC__
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)      \
  case BlockDim:                                                               \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>),       \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
                       loss_data, d, axis_dim);                                \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>),       \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
                       loss_data, softmax_data, d, axis_dim);                  \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForDiff<T, BlockDim>),      \
                       dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
                       loss_data, softmax_data, d, axis_dim);                  \
    break
#else
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
  case BlockDim:                                                        \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
        logits_data, loss_data, d, axis_dim);                           \
    RowReductionForDiffMaxSum<T, BlockDim,                              \
                              true><<<grid_dim, BlockDim, 0, stream>>>( \
        logits_data, loss_data, softmax_data, d, axis_dim);             \
    break
#endif

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL

  // The element-wise pass is independent of the block dimension, so it runs
  // once after the dispatch (the default case above throws).
  platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n * d);
  if (ignore_idx >= 0 && ignore_idx < axis_dim) {
    for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(
        labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));
  } else {
    for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(
        labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));
  }
}

// Fused softmax + soft-label cross-entropy. It launches the three row
// reductions described above (max, log-sum-exp, then softmax and loss) with
// one block per (n, remain) position. The block size is the largest power of
// two <= axis_dim, capped at kMaxBlockDim; axis_dim < 2 throws.
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(
    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
    int64_t n, int64_t d, int axis_dim, gpuStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  const int64_t block_size =
      axis_dim >= kMaxBlockDim ? kMaxBlockDim
                               : (1 << static_cast<int>(std::log2(axis_dim)));
  const int64_t num_blocks = n * d / axis_dim;

#ifdef __HIPCC__
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                 \
  case BlockDim:                                                               \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>),       \
                       dim3(num_blocks), dim3(BlockDim), 0, stream,            \
                       logits_data, loss_data, d, axis_dim);                   \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>),       \
                       dim3(num_blocks), dim3(BlockDim), 0, stream,            \
                       logits_data, loss_data, softmax_data, d, axis_dim);     \
    hipLaunchKernelGGL(                                                        \
        HIP_KERNEL_NAME(RowReductionForSoftmaxAndCrossEntropy<T, BlockDim>),   \
        dim3(num_blocks), dim3(BlockDim), 0, stream, logits_data, labels_data, \
        loss_data, softmax_data, d, axis_dim);                                 \
    break
#else
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                 \
  case BlockDim:                                                               \
    RowReductionForMax<T, BlockDim><<<num_blocks, BlockDim, 0, stream>>>(      \
        logits_data, loss_data, d, axis_dim);                                  \
    RowReductionForDiffMaxSum<T,                                               \
                              BlockDim><<<num_blocks, BlockDim, 0, stream>>>(  \
        logits_data, loss_data, softmax_data, d, axis_dim);                    \
    RowReductionForSoftmaxAndCrossEntropy<                                     \
        T, BlockDim><<<num_blocks, BlockDim, 0, stream>>>(                     \
        logits_data, labels_data, loss_data, softmax_data, d, axis_dim);       \
    break
#endif

  switch (block_size) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }

#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

// Cross-entropy loss for an input that is already a probability distribution;
// no softmax is applied here. `labels_data` is a T-typed label tensor (the
// soft-label caller passes labels->data<T>()).
//
// One RowReductionForCrossEntropy block is launched per reduced row
// (n * d / axis_dim blocks). The block size is the largest power of two not
// exceeding axis_dim, capped at 512.
//
// n and d are int64_t so that n * d cannot overflow a 32-bit int on large
// tensors; existing int arguments still convert implicitly.
template <typename T>
static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data,
                                    T* loss_data, int64_t n, int64_t d,
                                    int axis_dim, gpuStream_t stream) {
  // Nothing to reduce for empty input. This also avoids dividing by
  // axis_dim == 0, taking log2(0), and launching a zero-block grid.
  if (axis_dim <= 0 || n * d == 0) {
    return;
  }

  constexpr int kMaxBlockDim = 512;
  const int64_t block_dim =
      axis_dim >= kMaxBlockDim ? kMaxBlockDim
                               : (1 << static_cast<int>(std::log2(axis_dim)));
  const int64_t grid_dim = n * d / axis_dim;

#define CALL_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                         \
  case BlockDim:                                                          \
    RowReductionForCrossEntropy<T, BlockDim>                              \
        <<<grid_dim, BlockDim, 0, stream>>>(logits_data, labels_data,     \
                                            loss_data, d, axis_dim);      \
    break

  switch (block_dim) {
    CALL_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }

#undef CALL_CROSS_ENTROPY_FUSED_KERNEL
}

C
caoying03 已提交
763
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of softmax_with_cross_entropy on GPU.
  // When the "softmax_switch" attribute is false, "Logits" is treated as an
  // already-computed softmax: only the loss is computed and the input is
  // copied to the "Softmax" output. Otherwise softmax is computed from
  // "Logits" first.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    if (context.Attr<bool>("softmax_switch")) {
      ComputeWithSoftmax(context);
    } else {
      ComputeFromSoftmaxInput(context);
    }
  }

 private:
  // "Logits" already holds softmax probabilities: compute the cross-entropy
  // loss only and forward the input unchanged to the "Softmax" output.
  void ComputeFromSoftmaxInput(
      const framework::ExecutionContext& context) const {
    const Tensor* softmax = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax_out = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = softmax->dims().size();
    // Canonicalized and used to index dims(), so axis is non-negative here.
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    const int axis_dim = softmax->dims()[axis];

    // int64_t: n * d is the total element count and may exceed INT_MAX.
    const int64_t n = SizeToAxis(axis, softmax->dims());
    const int64_t d = SizeFromAxis(axis, softmax->dims());

    softmax_out->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    math::SetConstant<platform::CUDADeviceContext, T> set_constant;
    set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
    if (axis_dim == 1) {
      set_constant(context.cuda_device_context(), softmax_out,
                   static_cast<T>(1));
      return;
    }

    const bool soft_label = context.Attr<bool>("soft_label");
    const int ignore_index = context.Attr<int>("ignore_index");
    const auto& dev_ctx = context.cuda_device_context();

    if (axis == rank - 1) {
      // math::CrossEntropyFunctor handles the case where the class axis is
      // the last one. The previous check `axis == -1` could never be true
      // because axis is canonicalized to a non-negative index.
      Tensor softmax_2d, labels_2d, loss_2d;
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, 1});
      math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
          dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
          ignore_index, axis_dim);
    } else if (soft_label) {
      // Class axis is not the last one: use the axis-aware fused kernels.
      CrossEntropyFusedKernel(softmax->data<T>(), labels->data<T>(),
                              loss_data, n, d, axis_dim, dev_ctx.stream());
    } else {
      HardLabelCrossEntropy<T>(dev_ctx, softmax->data<T>(),
                               labels->data<int64_t>(), loss_data, n, d,
                               axis_dim, ignore_index);
    }

    // The input already is the softmax, so forward it to the output on
    // every path (including the CrossEntropyFunctor one).
    framework::TensorCopy(*softmax, context.GetPlace(),
                          context.device_context(), softmax_out);
  }

  // Compute softmax from "Logits", then the cross-entropy loss against
  // "Label".
  void ComputeWithSoftmax(const framework::ExecutionContext& context) const {
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    const int axis_dim = logits->dims()[axis];

    const int64_t n = SizeToAxis(axis, logits->dims());
    const int64_t d = SizeFromAxis(axis, logits->dims());

    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    if (axis_dim == 1) {
      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      return;
    }

    const bool soft_label = context.Attr<bool>("soft_label");
    const int ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      SoftmaxWithCrossEntropyFusedKernel(
          logits->data<T>(), labels->data<T>(), softmax_data, loss_data, n, d,
          axis_dim, context.cuda_device_context().stream());
    } else if (!context.Attr<bool>("numeric_stable_mode")) {
      // The cuDNN softmax functor only supports 2-D tensors and performs
      // softmax on the last dimension.
      Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
      logits_2d.ShareDataWith(*logits).Resize({n, d});
      softmax_2d.ShareDataWith(*softmax).Resize({n, d});
      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
      loss_2d.ShareDataWith(*loss).Resize({n, 1});
      math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                     &logits_2d, &softmax_2d);
      math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
          context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
          false, ignore_index, axis_dim);
    } else {
      HardLabelSoftmaxWithCrossEntropy<T>(
          context.cuda_device_context(), logits->data<T>(),
          labels->data<int64_t>(), loss_data, softmax_data, n, d, axis_dim,
          ignore_index);
    }
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward pass of softmax_with_cross_entropy on GPU.
  // Logits@GRAD is seeded with the forward "Softmax" output and then updated
  // in place by label-dependent kernels. The "softmax_switch" attribute
  // selects the kernels for the case where the forward op applied softmax
  // itself (true) or where "Logits" was already a softmax (false).
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    // Skip the copy when the gradient output and the Softmax input are the
    // same tensor object.
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->data<T>();

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    const int axis_dim = logit_grad->dims()[axis];

    const int64_t n = SizeToAxis(axis, logit_grad->dims());
    const int64_t d = SizeFromAxis(axis, logit_grad->dims());
    const int64_t remain = d / axis_dim;
    // Element counts and grid sizes stay in int64_t on every path: n * d can
    // exceed INT_MAX (the softmax_switch == false path used to truncate it).
    const int64_t num = n * d;

    constexpr int kBlock = 512;
    auto stream = context.cuda_device_context().stream();
    const int ignore_index = context.Attr<int>("ignore_index");
    const bool softmax_switch = context.Attr<bool>("softmax_switch");

    if (context.Attr<bool>("soft_label")) {
      const int64_t grid = (num + kBlock - 1) / kBlock;
      const T* label_data = labels->data<T>();
      if (softmax_switch) {
        SoftCrossEntropyGradientKernel<T><<<grid, kBlock, 0, stream>>>(
            logit_grad_data, loss_grad_data, label_data, n, d, remain);
      } else {
        // Forward input was already a softmax.
        SoftLabelCrossEntropyGradientKernel<T><<<grid, kBlock, 0, stream>>>(
            logit_grad_data, loss_grad_data, label_data, n, d, remain);
      }
      return;
    }

    // Hard labels: one launch over the n * remain label positions, then an
    // element-wise launch over all n * d gradient entries.
    const int64_t* label_data = labels->data<int64_t>();
    const int64_t label_grid = (n * remain + kBlock - 1) / kBlock;
    const int64_t elem_grid = (num + kBlock - 1) / kBlock;
    if (softmax_switch) {
      CrossEntropyGrad<T><<<label_grid, kBlock, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
      Scale<T><<<elem_grid, kBlock, 0, stream>>>(
          logit_grad_data, loss_grad_data, num, d, remain, label_data,
          ignore_index);
    } else {
      // Forward input was already a softmax.
      HardLabelCrossEntropyGradientKernel<T>
          <<<label_grid, kBlock, 0, stream>>>(logit_grad_data, label_data, n,
                                              d, remain, ignore_index);
      ScaleCrossEntropyGradient<T><<<elem_grid, kBlock, 0, stream>>>(
          logit_grad_data, loss_grad_data, num, d, remain, label_data,
          ignore_index);
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Register the GPU kernels for softmax_with_cross_entropy and its gradient.
namespace ops = paddle::operators;
974 975 976 977 978 979 980 981 982 983
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so HIP builds register only the float and
// float16 kernels for the forward op and its gradient.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>);
#else
984 985 986 987 988 989 990 991 992
// CUDA builds register float, float16 and double kernels for both the
// forward op and its gradient.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);
993
#endif