softmax_with_cross_entropy_op.cu 21.8 KB
Newer Older
S
sneaxiy 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
S
sneaxiy 已提交
11 12
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
13
#include "paddle/fluid/operators/math/math_function.h"
Y
Yi Wang 已提交
14
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
15
#include "paddle/fluid/platform/for_range.h"
16

C
caoying03 已提交
17 18 19 20 21
namespace paddle {
namespace operators {

// Short name for the framework tensor type used throughout this file.
typedef framework::Tensor Tensor;

22
namespace {
C
caoying03 已提交
23
// Hard-label backward, step 1: subtract 1 from the gradient entry that
// belongs to the target class.
//
// logit_grad is viewed as [n, axis_dim, remain] (d = axis_dim * remain
// elements per sample) and labels as [n, remain]. Each loop iteration handles
// one label, so the loop covers n * remain positions.
//
// A label equal to `ignore_index` is skipped. So is any label outside
// [0, axis_dim): writing at such a label would address memory outside the
// current sample's row (or outside the buffer entirely).
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int64_t n, const int64_t d,
                                 const int64_t remain, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) {
    int64_t idx_n = index / remain;
    int64_t idx_remain = index % remain;
    int64_t tmp = labels[index];
    // The body only runs when n * remain > 0, so remain > 0 here.
    const int64_t axis_dim = d / remain;
    if (tmp != ignore_index && tmp >= 0 && tmp < axis_dim) {
      int64_t idx = idx_n * d + tmp * remain + idx_remain;
      logit_grad[idx] -= static_cast<T>(1.);
    }
  }
}
Y
Yu Yang 已提交
37

38
// Hard-label backward, step 2: scale every gradient entry by the upstream
// loss gradient of its (sample, position) pair, or zero it when that pair's
// label equals `ignore_index`.
//
// logit_grad is viewed as [n, axis_dim, remain] flattened to `num` elements;
// loss_grad and labels are viewed as [n, remain].
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num,
                      const int64_t d, const int64_t remain,
                      const int64_t* labels, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
    // Drop the class coordinate to locate this element's label / loss slot.
    const int64_t label_pos = (index / d) * remain + index % remain;
    const bool ignored = labels[label_pos] == ignore_index;
    if (!ignored) {
      logit_grad[index] *= loss_grad[label_pos];
    } else {
      logit_grad[index] = static_cast<T>(0.);
    }
  }
}

// Soft-label backward: logit_grad = loss_grad * (logit_grad - labels).
//
// The caller copies the forward softmax into logit_grad before launching, so
// this computes loss_grad * (softmax - labels). logit_grad and labels are
// viewed as [n, axis_dim, remain] (d = axis_dim * remain elements per sample)
// and loss_grad as [n, remain]. Expects a 1-D launch with at least n * d
// threads in total; each thread handles one element.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int64_t n,
                                               const int64_t d,
                                               const int64_t remain) {
  // Widen before multiplying: blockIdx.x * blockDim.x is evaluated in 32-bit
  // unsigned arithmetic and would wrap once n * d exceeds 2^32 elements.
  int64_t ids = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int64_t idx_n = ids / d;
    int64_t idx_remain = ids % remain;
    int64_t idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}
S
sneaxiy 已提交
68

69
}  // namespace
C
caoying03 已提交
70

71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Device-side exp / log overloads for float, double and platform::float16.
// Every log overload passes its result through math::TolerableValue<T>
// (presumably to keep inf/nan out of the loss; see math/cross_entropy.h).
static __device__ __forceinline__ float exp_on_device(float x) {
  return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
  return exp(x);
}
static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  return ::Eigen::numext::exp(x);
}

static __device__ __forceinline__ float log_on_device(float x) {
  const float raw = logf(x);
  return math::TolerableValue<float>()(raw);
}
static __device__ __forceinline__ double log_on_device(double x) {
  const double raw = log(x);
  return math::TolerableValue<double>()(raw);
}
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 x) {
  const platform::float16 raw = ::Eigen::numext::log(x);
  return math::TolerableValue<platform::float16>()(raw);
}

/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss **/
/*
  Supposing x is `logits` and y is `labels`, the equations are as follows:
  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{k}e^{x_i_k}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{k}e^{x_i_k - max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{k}e^{x_i_k - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{j}{x_i_j}
  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, we can share memory among max_i, logDiffMaxSum_i and
cross\_entropy_i.
In this way, the 3 steps should be changed to:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
softmax'_i_j = x_i_j - max_i, and row-wise operation to calculate
logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j =
softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
cross\_entropy_i
*/

// Block-wide reduction primitive shared by the row-reduction kernels below.
// cub offers three algorithms: BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY,
// BLOCK_REDUCE_RAKING and BLOCK_REDUCE_WARP_REDUCTIONS. No algorithm argument
// is passed here, so cub's default (BLOCK_REDUCE_WARP_REDUCTIONS) is used.
template <typename T, int BlockDim>
using BlockReduce = cub::BlockReduce<T, BlockDim>;

// Shared-memory scratch space required by BlockReduce<T, BlockDim>.
template <typename T, int BlockDim>
using BlockReduceTempStorage =
    typename cub::BlockReduce<T, BlockDim>::TempStorage;

132
// Step 1: computes the maximum of every row of logits_data.
//
// logits_data is viewed as [n, axis_dim, remain] with d = axis_dim * remain,
// and max_data as [n, 1, remain]. Launch with gridDim.x = n * remain (one block
// per (idx_n, idx_remain) row) and blockDim.x = BlockDim.
// Precondition: BlockDim <= axis_dim, so every thread owns at least one
// element of its row.
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  // Thread t visits class indices t, t + BlockDim, t + 2 * BlockDim, ...
  const int64_t stride = BlockDim * remain;
  const int64_t first = row * d + threadIdx.x * remain + col;
  const int64_t last = (row + 1) * d;

  T thread_max = logits_data[first];
  for (int64_t i = first + stride; i < last; i += stride) {
    if (thread_max < logits_data[i]) thread_max = logits_data[i];
  }

  thread_max =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(thread_max, cub::Max());

  // The block-wide result is only guaranteed in thread 0.
  if (threadIdx.x == 0) max_data[blockIdx.x] = thread_max;
}

163
// Step 2: stores the shifted logits x_j - max into softmax and computes
// log(sum_j exp(x_j - max)) for every row.
//
// logits_data and softmax are viewed as [n, axis_dim, remain]. max_data is
// viewed as [n, 1, remain] and must hold the row maxima on entry (as written by
// RowReductionForMax). On exit it holds log(sum_j exp(x_j - max)) per row.
// When CalculateLogSoftmax is true, the kernel also subtracts that value from
// softmax (leaving log-softmax there) and then resets max_data to 0.
// Launch with gridDim.x = n * remain and blockDim.x = BlockDim; requires
// BlockDim <= axis_dim.
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax,
                                                 int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  const int64_t stride = BlockDim * remain;
  const int64_t first = row * d + threadIdx.x * remain + col;
  const int64_t last = (row + 1) * d;

  const T row_max = max_data[blockIdx.x];

  // Numerically stable formulation: the loss later uses
  // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i instead of
  // log(exp(x_i_j - max_i) / DiffMaxSum_i), so log(0) cannot occur, and
  // softmax_i_j = e^{tmp_i_j} stays within [0, 1]. Hence no clipping is needed
  // on the shifted softmax.
  softmax[first] = logits_data[first] - row_max;
  T partial_sum = exp_on_device(softmax[first]);
  for (int64_t i = first + stride; i < last; i += stride) {
    softmax[i] = logits_data[i] - row_max;
    partial_sum += exp_on_device(softmax[i]);
  }

  const T row_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(partial_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(row_sum);

  if (CalculateLogSoftmax) {
    // Make thread 0's write to max_data[blockIdx.x] visible to the block.
    __syncthreads();
    const T log_sum = max_data[blockIdx.x];
    for (int64_t i = first; i < last; i += stride) {
      softmax[i] -= log_sum;
    }
    // Note(zhiqiu): every thread must finish reading max_data[blockIdx.x]
    // before thread 0 overwrites it below.
    __syncthreads();
    if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
  }
}

217
// Step 3 (soft labels): finishes softmax and computes the per-row loss.
//
// On entry, softmax (viewed as [n, axis_dim, remain]) holds x_j - max, and
// loss_data (viewed as [n, 1, remain]) holds log(sum_j exp(x_j - max)) per row.
// For every element the kernel forms tmp = softmax - log_sum, writes exp(tmp)
// back into softmax, and accumulates -label * tmp. The block-wide sum is
// stored in loss_data. logits_data is not read.
// Launch with gridDim.x = n * remain and blockDim.x = BlockDim; requires
// BlockDim <= axis_dim.
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
    int64_t d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  const int64_t remain = d / axis_dim;
  const int64_t row = blockIdx.x / remain;
  const int64_t col = blockIdx.x % remain;
  const int64_t stride = BlockDim * remain;
  const int64_t first = row * d + threadIdx.x * remain + col;
  const int64_t last = (row + 1) * d;

  // loss_data doubles as the buffer holding log(sum_j exp(x_j - max)).
  const T log_sum = loss_data[blockIdx.x];

  T shifted = softmax[first] - log_sum;
  softmax[first] = exp_on_device(shifted);
  T partial_loss = -labels_data[first] * shifted;
  for (int64_t i = first + stride; i < last; i += stride) {
    shifted = softmax[i] - log_sum;
    softmax[i] = exp_on_device(shifted);
    partial_loss -= (labels_data[i] * shifted);
  }

  const T row_loss =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(partial_loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = row_loss;
}

// Element-wise epilogue of the hard-label forward pass (used via ForRange).
//
// log_softmax holds log-softmax values viewed as [n, axis_dim, remain]
// (d = axis_dim * remain). loss is viewed as [n, remain]. For every element
// the functor replaces the log-softmax value with its exponent (the softmax
// probability). At the element whose class index equals the label, it also
// stores the negated log-softmax value into loss. A label outside
// [0, axis_dim) matches no element, so its loss entry is left untouched.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int64_t d,
                                          int axis_dim)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim) {}

  __device__ void operator()(int64_t idx) const {
    const int64_t remain = d_ / axis_dim_;
    const int64_t sample = idx / d_;
    const int64_t cls = (idx % d_) / remain;
    const int64_t label_pos = sample * remain + idx % remain;

    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    if (cls == labels_[label_pos]) {
      loss_[label_pos] = -log_prob;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int64_t d_;
  int axis_dim_;
};

// Element-wise epilogue of the hard-label forward pass, with ignore_index.
//
// Every log_softmax element (viewed as [n, axis_dim, remain],
// d = axis_dim * remain) is replaced by its exponent. At the element whose
// class index equals the label, the negated log-softmax value is also stored
// into loss (viewed as [n, remain]), unless that label equals ignore_idx.
// Ignored and out-of-range labels leave their loss entry untouched.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int64_t d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int64_t idx) const {
    const int64_t remain = d_ / axis_dim_;
    const int64_t sample = idx / d_;
    const int64_t cls = (idx % d_) / remain;
    const int64_t label_pos = sample * remain + idx % remain;

    const T log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    if (cls == labels_[label_pos] && cls != ignore_idx_) {
      loss_[label_pos] = -log_prob;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int64_t d_;
  int axis_dim_;
  int ignore_idx_;
};

// Runs the hard-label forward stages for one compile-time block size:
//   1. row maxima into loss_data,
//   2. log-softmax into softmax_data (loss_data is reset to 0 afterwards),
//   3. element-wise exp of softmax_data plus pick-up of the label's loss.
template <typename T, int BlockDim>
static void LaunchHardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
    int64_t d, int axis_dim, int ignore_idx, int64_t grid_dim) {
  auto stream = ctx.stream();
  RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(
      logits_data, loss_data, d, axis_dim);
  RowReductionForDiffMaxSum<T, BlockDim, true><<<grid_dim, BlockDim, 0,
                                                 stream>>>(
      logits_data, loss_data, softmax_data, d, axis_dim);

  platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n * d);
  const bool ignore_idx_in_range = ignore_idx >= 0 && ignore_idx < axis_dim;
  if (ignore_idx_in_range) {
    for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(
        labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx));
  } else {
    for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(
        labels_data, loss_data, softmax_data, d, axis_dim));
  }
}

// Hard-label forward pass. softmax_data receives the softmax of logits_data
// along the class axis. loss_data receives -log(softmax) at each label's class,
// and stays 0 for ignored or out-of-range labels. The block size is the largest
// power of two not exceeding axis_dim, capped at 512, and is dispatched to a
// compile-time template argument.
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
    int64_t d, int axis_dim, int ignore_idx) {
  constexpr int kMaxBlockDim = 512;
  const int64_t block_dim =
      axis_dim < kMaxBlockDim ? (1 << static_cast<int>(std::log2(axis_dim)))
                              : kMaxBlockDim;
  // One block per (idx_n, idx_remain) row: n * remain blocks.
  const int64_t grid_dim = n * d / axis_dim;

  switch (block_dim) {
    case 512:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 512>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 256:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 256>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 128:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 128>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 64:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 64>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 32:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 32>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 16:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 16>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 8:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 8>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 4:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 4>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    case 2:
      LaunchHardLabelSoftmaxWithCrossEntropy<T, 2>(
          ctx, logits_data, labels_data, loss_data, softmax_data, n, d,
          axis_dim, ignore_idx, grid_dim);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }
}

S
sneaxiy 已提交
376
// Runs the three soft-label forward stages for one compile-time block size.
// loss_data is first used as scratch for the row max and the log-sum, and
// finally receives the per-row loss.
template <typename T, int BlockDim>
static void LaunchSoftmaxWithCrossEntropyFusedKernel(
    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
    int64_t d, int axis_dim, int64_t grid_dim, cudaStream_t stream) {
  RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(
      logits_data, loss_data, d, axis_dim);
  RowReductionForDiffMaxSum<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(
      logits_data, loss_data, softmax_data, d, axis_dim);
  RowReductionForSoftmaxAndCrossEntropy<T, BlockDim><<<grid_dim, BlockDim, 0,
                                                       stream>>>(
      logits_data, labels_data, loss_data, softmax_data, d, axis_dim);
}

// Soft-label forward pass. softmax_data receives the softmax of logits_data
// along the class axis, and loss_data the per-row cross entropy against
// labels_data. The block size is the largest power of two not exceeding
// axis_dim, capped at 512, and is dispatched to a compile-time template
// argument.
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(
    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
    int64_t n, int64_t d, int axis_dim, cudaStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  const int64_t block_dim =
      axis_dim < kMaxBlockDim ? (1 << static_cast<int>(std::log2(axis_dim)))
                              : kMaxBlockDim;
  // One block per (idx_n, idx_remain) row: n * remain blocks.
  const int64_t grid_dim = n * d / axis_dim;

  switch (block_dim) {
    case 512:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 512>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 256:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 256>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 128:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 128>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 64:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 64>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 32:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 32>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 16:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 16>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 8:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 8>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 4:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 4>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    case 2:
      LaunchSoftmaxWithCrossEntropyFusedKernel<T, 2>(
          logits_data, labels_data, softmax_data, loss_data, d, axis_dim,
          grid_dim, stream);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
      break;
  }
}

C
caoying03 已提交
416
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward pass: writes Softmax (softmax of Logits along `axis`) and Loss
  // (cross entropy of that softmax against Label). The work is described by
  // n = SizeToAxis(axis, dims) and d = SizeFromAxis(axis, dims); the cuDNN
  // path below views Logits as an [n, d] matrix.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    const int64_t n = SizeToAxis(axis, logits->dims());
    const int64_t d = SizeFromAxis(axis, logits->dims());

    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    // Empty input: the outputs are allocated above and there is nothing to
    // compute. Returning here also avoids the division by n
    // (labels->numel() / n) and launching kernels with a zero-sized grid,
    // which CUDA rejects as an invalid configuration.
    if (n == 0 || d == 0) {
      return;
    }

    // With a single class the softmax is identically 1 and the loss is 0.
    if (axis_dim == 1) {
      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      return;
    }

    auto soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      auto* logits_data = logits->data<T>();
      auto* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim,
          context.cuda_device_context().stream());
    } else {
      if (!context.Attr<bool>("numeric_stable_mode")) {
        // The cuDNN kernel only supports 2-D tensors and performs softmax on
        // the last dim, so this path treats Logits as [n, d].
        // NOTE(review): this is only correct when `axis` is the last dimension
        // (d == axis_dim); confirm that the op definition enforces it.
        Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
        logits_2d.ShareDataWith(*logits).Resize({n, d});
        softmax_2d.ShareDataWith(*softmax).Resize({n, d});
        labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
        loss_2d.ShareDataWith(*loss).Resize({n, 1});
        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                       &logits_2d, &softmax_2d);
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            false, ignore_index, axis_dim);
      } else {
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
            context.cuda_device_context(), logits_data, labels_data, loss_data,
            softmax_data, n, d, axis_dim, ignore_index);
      }
    }
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward pass: computes Logits@GRAD. The forward Softmax is copied into
  // Logits@GRAD (unless both are the same tensor), which is then updated in
  // place:
  //   soft_label: grad = loss_grad * (softmax - label)
  //   hard label: grad = loss_grad * (softmax - one_hot(label)), with the
  //               whole class slice zeroed where label == ignore_index.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                      "CUDA kernel only runs on GPU device."));
    const Tensor* labels = context.Input<Tensor>("Label");
    const Tensor* loss_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    // Skip the copy when the gradient output aliases the Softmax input.
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int64_t n = SizeToAxis(axis, logit_grad->dims());
    const int64_t d = SizeFromAxis(axis, logit_grad->dims());

    // Empty input: there is no gradient to compute. Returning here also
    // avoids dividing by axis_dim (which may be 0 when d == 0) and launching
    // kernels with a zero-sized grid, which CUDA rejects as an invalid
    // configuration.
    if (n == 0 || d == 0) {
      return;
    }
    const int64_t remain = d / axis_dim;

    const T* loss_grad_data = loss_grad->data<T>();
    T* logit_grad_data = logit_grad->data<T>();

    int block = 512;
    auto stream = context.cuda_device_context().stream();
    auto ignore_index = context.Attr<int>("ignore_index");
    if (context.Attr<bool>("soft_label")) {
      // One thread per logit element.
      int64_t grid = (n * d + block - 1) / block;
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
    } else {
      // Subtract 1 at each label's class (one thread per label) ...
      int64_t grid = (n * remain + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
      // ... then scale by the upstream gradient (one thread per element).
      int64_t num = n * d;
      grid = (num + block - 1) / block;
      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                           d, remain, label_data, ignore_index);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward kernels for float, float16 and double.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<plat::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
// Backward kernels for the same three element types.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<plat::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);