/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

namespace {
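// Throughout this file a [n, d] logits tensor is viewed as [n, axis_dim,
// remain] with d = axis_dim * remain, and softmax is taken over the axis_dim
// dimension. Illustrative example (hypothetical shapes): logits of shape
// [2, 3, 4] with axis = 1 give n = 2, axis_dim = 3, d = 12, remain = 4, and
// element (i, j, k) lives at linear index i * d + j * remain + k.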
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int n, const int d, const int remain,
                                 const int ignore_index) {
  CUDA_KERNEL_LOOP(index, n * remain) {
    int idx_n = index / remain;
    int idx_remain = index % remain;
    int tmp = labels[index];
    if (ignore_index != tmp) {
      int idx = idx_n * d + tmp * remain + idx_remain;
      logit_grad[idx] -= static_cast<T>(1.);
    }
  }
}

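// For hard labels, CrossEntropyGrad above subtracts 1 from the softmax entry
// at the target class, and Scale below multiplies each entry of the row by the
// incoming loss gradient (zeroing rows whose label equals ignore_index), so the
// overall effect is d(loss)/d(logit) = (softmax - one_hot(label)) * loss_grad.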
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int d, const int remain, const int64_t* labels,
                      const int ignore_index) {
  CUDA_KERNEL_LOOP(index, num) {
    int idx_n = index / d;
    int idx_remain = index % remain;
    int idx_lbl = idx_n * remain + idx_remain;
    if (labels[idx_lbl] == ignore_index) {
      logit_grad[index] = static_cast<T>(0.);
    } else {
      logit_grad[index] *= loss_grad[idx_lbl];
    }
  }
}

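// For soft labels, logit_grad holds the forward softmax on entry, so the
// gradient below is simply d(loss)/d(logit) = loss_grad * (softmax - label).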
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int n,
                                               const int d, const int remain) {
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int idx_n = ids / d;
    int idx_remain = ids % remain;
    int idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}

}  // namespace

static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float exp_on_device(float x) {
  return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
  return exp(x);
}
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 x) {
  return math::TolerableValue<platform::float16>()(::Eigen::numext::log(x));
}
static __device__ __forceinline__ float log_on_device(float x) {
  return math::TolerableValue<float>()(logf(x));
}
static __device__ __forceinline__ double log_on_device(double x) {
  return math::TolerableValue<double>()(log(x));
}

/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss **/
/*
  Suppose x is `logits` and y is `labels`; the equations are as
follows:
  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{j}{x_i_j}
  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, we can share memory among max_i, logDiffMaxSum_i and
cross\_entropy_i.
In this way, the 3 steps should be changed to:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate immediate result of softmax'_i_j =
x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j
- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i
*/
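// A small worked example with illustrative numbers: for one row x = [1, 2, 3]
// and a one-hot label at j = 2,
//   max_i           = 3
//   logDiffMaxSum_i = log(e^{-2} + e^{-1} + e^{0}) ~= 0.4076
//   tmp_i           = [-2.4076, -1.4076, -0.4076]
//   softmax_i       = e^{tmp_i} ~= [0.0900, 0.2447, 0.6652]
//   cross_entropy_i = -tmp_i_2 ~= 0.4076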

// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template <typename T, int BlockDim>
using BlockReduce =
    cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;

template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

// Make sure that BlockDim <= axis_dim
// This kernel is used to calculate the max element of each row
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits_data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // gridDim.x = n * remain; blockIdx.x is split into idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  int step = BlockDim * remain;
  T cur_max = logits_data[beg_idx];
  beg_idx += step;
  while (beg_idx < end_idx) {
    if (cur_max < logits_data[beg_idx]) {
      cur_max = logits_data[beg_idx];
    }
    beg_idx += step;
  }

  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());

  if (threadIdx.x == 0) max_data[blockIdx.x] = cur_max;
}

// Make sure that BlockDim <= axis_dim
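// When CalculateLogSoftmax is true (the hard-label path), this kernel leaves
// log-softmax values (x_i_j - max_i - logDiffMaxSum_i) in `softmax` and resets
// max_data to 0; the hard-label functors below exponentiate those entries in
// place.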
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax, int d,
                                                 int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // gridDim.x = n * remain; blockIdx.x is split into idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  auto block_max = max_data[blockIdx.x];
  int step = BlockDim * remain;

  // In numerically stable mode, the loss is computed with
  // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i instead of
  // log(exp(x_i_j - max_i) / DiffMaxSum_i), so log(0) can never occur.
  // The softmax is then softmax_i_j = e^{tmp_i_j}, whose values lie in
  // [0.0, 1.0], i.e. valid probabilities, so there is no need to clip
  // shift_softmax.
  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  auto idx = beg_idx + step;
  while (idx < end_idx) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
    idx += step;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);

  if (!CalculateLogSoftmax) return;
  __syncthreads();
  diff_max_sum = max_data[blockIdx.x];
  softmax[beg_idx] -= diff_max_sum;
  beg_idx += step;
  while (beg_idx < end_idx) {
    softmax[beg_idx] -= diff_max_sum;
    beg_idx += step;
  }

  // Note(zhiqiu): since different threads may use max_data[blockIdx.x] to
  // calculate diff_max_sum, __syncthreads() is needed here.
  __syncthreads();
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}

// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d,
    int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax, labels data view as [n, axis_dim, remain]
  // loss_data view as [n, 1, remain]
  // gridDim.x = n * remain; blockIdx.x is split into idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  // log_diff_max_sum shares memory with loss
  auto block_log_diff_max_sum = loss_data[blockIdx.x];
  auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
  softmax[beg_idx] = exp_on_device(tmp);
  auto loss = -labels_data[beg_idx] * tmp;
  int step = BlockDim * remain;
  beg_idx += step;
  while (beg_idx < end_idx) {
    tmp = softmax[beg_idx] - block_log_diff_max_sum;
    softmax[beg_idx] = exp_on_device(tmp);
    loss -= (labels_data[beg_idx] * tmp);
    beg_idx += step;
  }

  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}

template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int d, int axis_dim)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    // Labels outside the range [0, class_num) are also ignored.
    if (idx_axis != labels_[idx_lbl]) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
};

template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
  int ignore_idx_;
};

template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d,
    int axis_dim, int ignore_idx) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;
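  // block_dim is the largest power of two not exceeding axis_dim (capped at
  // kMaxBlockDim), e.g. axis_dim = 100 gives block_dim = 64; grid_dim equals
  // n * remain, i.e. one block per reduced row.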
  auto stream = ctx.stream();

#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)  \
  case BlockDim: {                                                         \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, d, axis_dim);                              \
    RowReductionForDiffMaxSum<T, BlockDim,                                 \
                              true><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, softmax_data, d, axis_dim);                \
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d);  \
    if (ignore_idx >= 0 && ignore_idx < axis_dim) {                        \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(   \
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
    } else {                                                               \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                \
          labels_data, loss_data, softmax_data, d, axis_dim));             \
    }                                                                      \
  } break

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                               const T* labels_data,
                                               T* softmax_data, T* loss_data,
                                               int n, int d, int axis_dim,
                                               cudaStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;

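// The three kernels chained below reuse loss_data as scratch space: it first
// receives the row-wise max, is then overwritten with log(sum(exp(x - max))),
// and finally holds the cross-entropy loss.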
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                 \
  case BlockDim:                                                               \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(        \
        logits_data, loss_data, d, axis_dim);                                  \
    RowReductionForDiffMaxSum<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
        logits_data, loss_data, softmax_data, d, axis_dim);                    \
    RowReductionForSoftmaxAndCrossEntropy<                                     \
        T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(                       \
        logits_data, labels_data, loss_data, softmax_data, d, axis_dim);       \
    break

  switch (block_dim) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }

#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());

    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

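    // Trivial case: with a single class along the softmax axis, the softmax is
    // identically 1 and the loss is identically 0.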
    if (axis_dim == 1) {
      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      return;
    }

    auto soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      auto* logits_data = logits->data<T>();
      auto* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim,
          context.cuda_device_context().stream());
    } else {
      if (!context.Attr<bool>("numeric_stable_mode")) {
        // The cuDNN kernel only supports 2-D tensors and performs softmax on
        // the last dim.
        Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
        logits_2d.ShareDataWith(*logits).Resize({n, d});
        softmax_2d.ShareDataWith(*softmax).Resize({n, d});
        labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
        loss_2d.ShareDataWith(*loss).Resize({n, 1});
        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                       &logits_2d, &softmax_2d);
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            false, ignore_index, axis_dim);
      } else {
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
            context.cuda_device_context(), logits_data, labels_data, loss_data,
            softmax_data, n, d, axis_dim, ignore_index);
      }
    }
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->data<T>();

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int n = SizeToAxis(axis, logit_grad->dims());
    const int d = SizeFromAxis(axis, logit_grad->dims());
    const int remain = d / axis_dim;

    int block = 512;
    auto stream = context.cuda_device_context().stream();
    auto ignore_index = context.Attr<int>("ignore_index");
    if (context.Attr<bool>("soft_label")) {
      int grid = (n * d + block - 1) / block;
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
    } else {
      int grid = (n * remain + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
      int num = n * d;
      grid = (num + block - 1) / block;
      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                           d, remain, label_data, ignore_index);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);