softmax_with_cross_entropy_op.cu 21.0 KB
Newer Older
S
sneaxiy 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
L
Luo Tao 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
C
caoying03 已提交
6 7 8 9 10
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
S
sneaxiy 已提交
11 12
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
13
#include "paddle/fluid/operators/math/math_function.h"
Y
Yi Wang 已提交
14
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
15
#include "paddle/fluid/platform/for_range.h"
16

C
caoying03 已提交
17 18 19 20 21
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

22
namespace {
C
caoying03 已提交
23
// Hard-label backward, step 1: subtracts one from the gradient entry of the
// labelled class at every (sample, remain) position.
// logit_grad is indexed as [n, d / remain, remain] (it is expected to already
// hold the softmax output) and labels as [n, remain], one class id per
// position. Positions whose label equals ignore_index are skipped: such a label
// need not be a valid class id (e.g. a negative ignore_index), so it must never
// be used to form an address, and its contribution here is zero anyway.
// Launch: any 1-D grid; the loop strides by blockDim.x * gridDim.x.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int n, const int d, const int remain,
                                 const int ignore_index) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n * remain;
       i += blockDim.x * gridDim.x) {
    const int64_t label = labels[i];
    if (label == ignore_index) continue;
    const int64_t idx_n = i / remain;
    const int64_t idx_remain = i % remain;
    // 64-bit arithmetic: label is int64_t and the flat offset may exceed int.
    const int64_t idx = idx_n * d + label * remain + idx_remain;
    logit_grad[idx] -= static_cast<T>(1.);
  }
}
Y
Yu Yang 已提交
36

37 38
// Hard-label backward, step 2: multiplies every gradient element by the
// upstream loss gradient of its (sample, remain) position.
// logit_grad holds num = n * d elements viewed as [n, d / remain, remain];
// loss_grad is laid out as [n, remain].
// Launch: any 1-D grid; the loop strides by blockDim.x * gridDim.x.
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int d, const int remain) {
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < num) {
    const int loss_pos = (pos / d) * remain + pos % remain;
    logit_grad[pos] *= loss_grad[loss_pos];
    pos += blockDim.x * gridDim.x;
  }
}

// Soft-label backward: logit_grad = loss_grad * (softmax - labels), in place.
// logit_grad (holding softmax on entry) and labels have n * d elements viewed
// as [n, d / remain, remain]; loss_grad is laid out as [n, remain].
// Launch: one thread per element; threads past n * d do nothing.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int n,
                                               const int d, const int remain) {
  const int pos = blockIdx.x * blockDim.x + threadIdx.x;
  if (pos >= n * d) {
    return;
  }
  const int loss_pos = (pos / d) * remain + pos % remain;
  const T prob = logit_grad[pos];
  logit_grad[pos] = loss_grad[loss_pos] * (prob - labels[pos]);
}
S
sneaxiy 已提交
61

62
}  // namespace
C
caoying03 已提交
63

64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
// Device-side exp/log overloads, one pair per element type this operator is
// registered for. The log variants pass their result through
// math::TolerableValue<T> before returning it.

// platform::float16 goes through Eigen's numext functions.
static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 value) {
  return ::Eigen::numext::exp(value);
}
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 value) {
  return math::TolerableValue<platform::float16>()(
      ::Eigen::numext::log(value));
}

// float uses the single-precision math functions.
static __device__ __forceinline__ float exp_on_device(float value) {
  return expf(value);
}
static __device__ __forceinline__ float log_on_device(float value) {
  return math::TolerableValue<float>()(logf(value));
}

// double uses the double-precision math functions.
static __device__ __forceinline__ double exp_on_device(double value) {
  return exp(value);
}
static __device__ __forceinline__ double log_on_device(double value) {
  return math::TolerableValue<double>()(log(value));
}

/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss **/
/*
  Supposing x is `logits` and y is `labels`, the equations are as follows:
  cross\_entropy_i = \sum_{j}[-y_i_j * log(e^{x_i_j} / \sum_{k}e^{x_i_k})]
        = \sum_{j}[-y_i_j * log(e^{x_i_j - max_i} / \sum_{k}e^{x_i_k - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{k}e^{x_i_k - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{k}{x_i_k}
  logDiffMaxSum_i = log\sum_{k}e^{x_i_k - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, max_i, logDiffMaxSum_i and cross\_entropy_i share the same
buffer.
The 3 steps therefore become:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
softmax'_i_j = x_i_j - max_i, and a row-wise operation to calculate
logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate
tmp_i_j = softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
cross\_entropy_i
*/

// cub::BlockReduce offers three reduction algorithms:
//   BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
//   BLOCK_REDUCE_RAKING
//   BLOCK_REDUCE_WARP_REDUCTIONS (cub's default)
// The default one is named explicitly here so the choice is visible; the
// resulting type is the same as with the defaulted template argument.
template <typename T, int BlockDim>
using BlockReduce =
    cub::BlockReduce<T, BlockDim, cub::BLOCK_REDUCE_WARP_REDUCTIONS>;

// Shared-memory scratch type that BlockReduce<T, BlockDim> needs per block.
template <typename T, int BlockDim>
using BlockReduceTempStorage =
    typename BlockReduce<T, BlockDim>::TempStorage;

125
// Step 1: computes the maximum of every row.
// Launch layout: one block per (n, remain) position (gridDim.x == n * remain)
// with BlockDim threads. logits_data is viewed as [n, axis_dim, remain] and
// max_data as [n, 1, remain]. Each thread walks the axis with a stride of
// BlockDim * remain. Requires BlockDim <= axis_dim so every thread owns at
// least one element.
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int remain = d / axis_dim;
  const int row = blockIdx.x / remain;
  const int col = blockIdx.x % remain;
  const int stride = BlockDim * remain;
  const int row_end = (row + 1) * d;

  // Seed the per-thread maximum with the first element this thread owns.
  int pos = row * d + threadIdx.x * remain + col;
  T thread_max = logits_data[pos];
  for (pos += stride; pos < row_end; pos += stride) {
    const T candidate = logits_data[pos];
    // Replace only when thread_max < candidate; a NaN candidate never wins.
    thread_max = (thread_max < candidate) ? candidate : thread_max;
  }

  const T block_max =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_max, cub::Max());
  // The block-wide result is only valid in thread 0.
  if (threadIdx.x == 0) {
    max_data[blockIdx.x] = block_max;
  }
}

156
// Step 2: shifts each row by its maximum and reduces log(sum(exp(shifted))).
// Launch layout: gridDim.x == n * remain, BlockDim threads, BlockDim <=
// axis_dim. logits_data and softmax are viewed as [n, axis_dim, remain];
// max_data as [n, 1, remain].
// On entry max_data holds the row maxima from RowReductionForMax.
// On exit:
//   * softmax holds x - max (or x - max - logDiffMaxSum when
//     CalculateLogSoftmax is true, i.e. the log-softmax);
//   * max_data holds logDiffMaxSum, or 0 when CalculateLogSoftmax is true.
// Working with shifted values keeps exp() at most 1, so log(0) cannot occur
// and no clipping of the shifted softmax is needed.
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax, int d,
                                                 int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int remain = d / axis_dim;
  const int row = blockIdx.x / remain;
  const int col = blockIdx.x % remain;
  const int stride = BlockDim * remain;
  const int row_begin = row * d + threadIdx.x * remain + col;
  const int row_end = (row + 1) * d;

  const T row_max = max_data[blockIdx.x];

  // Pass 1: store x - max and accumulate exp(x - max) per thread, seeding the
  // sum with the first element this thread owns.
  int pos = row_begin;
  T shifted = logits_data[pos] - row_max;
  softmax[pos] = shifted;
  T thread_sum = exp_on_device(shifted);
  for (pos += stride; pos < row_end; pos += stride) {
    shifted = logits_data[pos] - row_max;
    softmax[pos] = shifted;
    thread_sum += exp_on_device(shifted);
  }

  const T block_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_sum, cub::Sum());
  if (threadIdx.x == 0) {
    max_data[blockIdx.x] = log_on_device(block_sum);
  }

  // CalculateLogSoftmax is a template constant, so every thread of the block
  // takes the same branch and the barriers inside are safe.
  if (CalculateLogSoftmax) {
    // Make thread 0's logDiffMaxSum visible to the whole block.
    __syncthreads();
    const T log_sum = max_data[blockIdx.x];
    for (int i = row_begin; i < row_end; i += stride) {
      softmax[i] -= log_sum;
    }
    // Every thread must have read max_data[blockIdx.x] before it is reset.
    __syncthreads();
    if (threadIdx.x == 0) {
      max_data[blockIdx.x] = 0;
    }
  }
}

210
// Step 3 (soft labels): turns x - max into softmax and reduces the
// cross-entropy of each row.
// Launch layout: gridDim.x == n * remain, BlockDim threads, BlockDim <=
// axis_dim. logits_data, labels_data and softmax are viewed as
// [n, axis_dim, remain]; loss_data as [n, 1, remain].
// On entry softmax holds x - max and loss_data holds logDiffMaxSum (written by
// RowReductionForDiffMaxSum). On exit softmax holds exp(x - max -
// logDiffMaxSum) and loss_data holds -sum(label * (x - max - logDiffMaxSum)).
// logits_data is not read.
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d,
    int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int remain = d / axis_dim;
  const int row = blockIdx.x / remain;
  const int col = blockIdx.x % remain;
  const int stride = BlockDim * remain;
  const int row_end = (row + 1) * d;
  int pos = row * d + threadIdx.x * remain + col;

  // loss_data currently stores logDiffMaxSum for this row.
  const auto log_sum = loss_data[blockIdx.x];

  // Seed the per-thread loss with the first element this thread owns.
  auto log_prob = softmax[pos] - log_sum;
  softmax[pos] = exp_on_device(log_prob);
  auto thread_loss = -labels_data[pos] * log_prob;

  for (pos += stride; pos < row_end; pos += stride) {
    log_prob = softmax[pos] - log_sum;
    softmax[pos] = exp_on_device(log_prob);
    thread_loss -= (labels_data[pos] * log_prob);
  }

  thread_loss =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_loss, cub::Sum());
  if (threadIdx.x == 0) {
    loss_data[blockIdx.x] = thread_loss;
  }
}

// Per-element functor for the hard-label path (no ignore index).
// Element idx indexes logits viewed as [n, axis_dim, remain], where
// d = axis_dim * remain; labels and loss are viewed as [n, remain].
// log_softmax holds the log-softmax on entry and is exponentiated in place into
// the softmax. At the element of the labelled class, the loss of that position
// is set to minus the log-softmax.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int d, int axis_dim)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim) {}

  __device__ void operator()(int idx) const {
    const int remain = d_ / axis_dim_;
    const int sample = idx / d_;
    const int class_id = (idx % d_) / remain;
    const int inner = idx % remain;
    const int label_pos = sample * remain + inner;

    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    if (class_id == labels_[label_pos]) {
      loss_[label_pos] = -log_prob;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
};

// Per-element functor for the hard-label path with an ignore index.
// Element idx indexes logits viewed as [n, axis_dim, remain], where
// d = axis_dim * remain; labels and loss are viewed as [n, remain].
// log_softmax holds the log-softmax on entry and is exponentiated in place into
// the softmax. The loss of a position is written, as minus the log-softmax of
// the labelled class, only when that class is not ignore_idx.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int idx) const {
    const int remain = d_ / axis_dim_;
    const int sample = idx / d_;
    const int class_id = (idx % d_) / remain;
    const int inner = idx % remain;
    const int label_pos = sample * remain + inner;

    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    const bool is_target = class_id == labels_[label_pos];
    if (is_target && class_id != ignore_idx_) {
      loss_[label_pos] = -log_prob;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
  int ignore_idx_;
};

// Hard-label, numerically stable forward pass, launched on ctx's stream.
// loss_data doubles as scratch for the per-row statistics:
//   1. RowReductionForMax writes each row max into loss_data;
//   2. RowReductionForDiffMaxSum<..., true> writes the log-softmax into
//      softmax_data and resets loss_data to 0;
//   3. a ForRange functor exponentiates softmax_data in place and writes the
//      loss of each labelled class into loss_data.
// The ignore-index functor is used only when ignore_idx is a valid class id.
// Throws when the selected block size has no instantiation (axis_dim < 2).
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d,
    int axis_dim, int ignore_idx) {
  constexpr int kMaxBlockDim = 512;
  // Largest power of two not exceeding axis_dim, capped at kMaxBlockDim.
  int block_dim = kMaxBlockDim;
  if (axis_dim < kMaxBlockDim) {
    block_dim = 1 << static_cast<int>(std::log2(axis_dim));
  }
  const int grid_dim = n * d / axis_dim;
  const bool has_valid_ignore_idx = ignore_idx >= 0 && ignore_idx < axis_dim;
  auto stream = ctx.stream();

#define LAUNCH_HARD_LABEL_FUSED_KERNELS(kBlockDim)                             \
  case kBlockDim: {                                                          \
    RowReductionForMax<T, kBlockDim><<<grid_dim, kBlockDim, 0, stream>>>(    \
        logits_data, loss_data, d, axis_dim);                                \
    RowReductionForDiffMaxSum<T, kBlockDim, true>                            \
        <<<grid_dim, kBlockDim, 0, stream>>>(logits_data, loss_data,         \
                                             softmax_data, d, axis_dim);     \
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n * d);   \
    if (has_valid_ignore_idx) {                                              \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(     \
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));   \
    } else {                                                                 \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                  \
          labels_data, loss_data, softmax_data, d, axis_dim));               \
    }                                                                        \
    break;                                                                   \
  }

  switch (block_dim) {
    LAUNCH_HARD_LABEL_FUSED_KERNELS(512);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(256);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(128);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(64);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(32);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(16);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(8);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(4);
    LAUNCH_HARD_LABEL_FUSED_KERNELS(2);
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }
#undef LAUNCH_HARD_LABEL_FUSED_KERNELS
}

S
sneaxiy 已提交
366 367 368 369
// Soft-label, numerically stable forward pass: three kernel launches on
// `stream`. loss_data doubles as scratch for the row max (step 1) and for
// logDiffMaxSum (step 2) before step 3 overwrites it with the loss.
// Throws when the selected block size has no instantiation (axis_dim < 2).
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                               const T* labels_data,
                                               T* softmax_data, T* loss_data,
                                               int n, int d, int axis_dim,
                                               cudaStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  // Largest power of two not exceeding axis_dim, capped at kMaxBlockDim.
  int block_dim = kMaxBlockDim;
  if (axis_dim < kMaxBlockDim) {
    block_dim = 1 << static_cast<int>(std::log2(axis_dim));
  }
  const int grid_dim = n * d / axis_dim;

#define LAUNCH_SOFT_LABEL_FUSED_KERNELS(kBlockDim)                             \
  case kBlockDim: {                                                          \
    RowReductionForMax<T, kBlockDim><<<grid_dim, kBlockDim, 0, stream>>>(    \
        logits_data, loss_data, d, axis_dim);                                \
    RowReductionForDiffMaxSum<T, kBlockDim>                                  \
        <<<grid_dim, kBlockDim, 0, stream>>>(logits_data, loss_data,         \
                                             softmax_data, d, axis_dim);     \
    RowReductionForSoftmaxAndCrossEntropy<T, kBlockDim>                      \
        <<<grid_dim, kBlockDim, 0, stream>>>(logits_data, labels_data,       \
                                             loss_data, softmax_data, d,     \
                                             axis_dim);                      \
    break;                                                                   \
  }

  switch (block_dim) {
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(512);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(256);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(128);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(64);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(32);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(16);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(8);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(4);
    LAUNCH_SOFT_LABEL_FUSED_KERNELS(2);
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }
#undef LAUNCH_SOFT_LABEL_FUSED_KERNELS
}

C
caoying03 已提交
407
// Forward CUDA kernel of softmax_with_cross_entropy.
// Inputs "Logits" and "Label"; outputs "Softmax" and "Loss". The "axis"
// attribute selects the class axis; logits are treated as
// [n, axis_dim, remain] with n = SizeToAxis(axis) and d = SizeFromAxis(axis).
// Dispatch:
//   * axis_dim == 1                     -> Softmax filled with 1, Loss with 0;
//   * soft_label                        -> fused soft-label kernels;
//   * hard label, !numeric_stable_mode  -> cuDNN softmax + CrossEntropyFunctor;
//   * hard label, numeric_stable_mode   -> fused hard-label kernels.
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    // Allocate both outputs before any branch writes to them. The
    // axis_dim == 1 shortcut fills them through math::SetConstant, which is
    // handed the Tensor as-is; allocating here guarantees it always sees
    // allocated storage.
    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    if (axis_dim == 1) {
      // Single class: softmax is identically 1 and the loss is 0.
      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      return;
    }

    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());

    auto soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      auto* logits_data = logits->data<T>();
      auto* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim,
          context.cuda_device_context().stream());
    } else {
      if (!context.Attr<bool>("numeric_stable_mode")) {
        // The cuDNN softmax only supports 2-D tensors and normalizes over the
        // last dimension, so view every tensor as [n, *].
        Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
        logits_2d.ShareDataWith(*logits).Resize({n, d});
        softmax_2d.ShareDataWith(*softmax).Resize({n, d});
        labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
        loss_2d.ShareDataWith(*loss).Resize({n, 1});
        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                       &logits_2d, &softmax_2d);
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            false, ignore_index, axis_dim);
      } else {
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
            context.cuda_device_context(), logits_data, labels_data, loss_data,
            softmax_data, n, d, axis_dim, ignore_index);
      }
    }
  }
};

// Backward CUDA kernel of softmax_with_cross_entropy.
// Logits@GRAD is seeded with the forward Softmax output. For soft labels,
// SoftCrossEntropyGradientKernel turns it into loss_grad * (softmax - label).
// For hard labels, CrossEntropyGrad subtracts one at the labelled class and
// Scale multiplies every element by the loss gradient of its (n, remain)
// position.
template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");

    // Seed the gradient with softmax; nothing to copy when both names already
    // refer to the same tensor.
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->data<T>();

    // Shape information is read from logit_grad after the copy above.
    const auto& grad_dims = logit_grad->dims();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), grad_dims.size());
    const int axis_dim = grad_dims[axis];
    const int n = SizeToAxis(axis, grad_dims);
    const int d = SizeFromAxis(axis, grad_dims);
    const int remain = d / axis_dim;

    constexpr int kThreadsPerBlock = 512;
    auto stream = context.cuda_device_context().stream();
    const int ignore_index = context.Attr<int>("ignore_index");

    if (context.Attr<bool>("soft_label")) {
      const T* soft_labels = labels->data<T>();
      const int blocks = (n * d + kThreadsPerBlock - 1) / kThreadsPerBlock;
      SoftCrossEntropyGradientKernel<T><<<blocks, kThreadsPerBlock, 0, stream>>>(
          logit_grad_data, loss_grad_data, soft_labels, n, d, remain);
      return;
    }

    const int64_t* hard_labels = labels->data<int64_t>();
    const int label_blocks =
        (n * remain + kThreadsPerBlock - 1) / kThreadsPerBlock;
    CrossEntropyGrad<T><<<label_blocks, kThreadsPerBlock, 0, stream>>>(
        logit_grad_data, hard_labels, n, d, remain, ignore_index);

    const int num = n * d;
    const int elem_blocks = (num + kThreadsPerBlock - 1) / kThreadsPerBlock;
    Scale<T><<<elem_blocks, kThreadsPerBlock, 0, stream>>>(
        logit_grad_data, loss_grad_data, num, d, remain);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
519 520 521 522 523 524 525 526 527
// Forward kernel of softmax_with_cross_entropy, registered for float, float16
// and double.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
// Backward kernel, registered for the same three element types.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);