/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

namespace {
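// Hard-label gradient: d(loss)/d(logit) = softmax - one_hot(label). The
// forward softmax output has already been copied into logit_grad, so this
// kernel only subtracts 1 at each sample's label position; positions whose
// label equals ignore_index are left untouched.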
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int n, const int d, const int remain,
                                 const int ignore_index) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n * remain;
       i += blockDim.x * gridDim.x) {
    int idx_n = i / remain;
    int idx_remain = i % remain;
    int idx = idx_n * d + labels[i] * remain + idx_remain;
    logit_grad[idx] -=
        ignore_index == labels[i] ? static_cast<T>(0.) : static_cast<T>(1.);
  }
}

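// Scales every element of logit_grad by the upstream loss gradient of its
// (idx_n, idx_remain) slice, completing the hard-label gradient
// softmax - one_hot after CrossEntropyGrad has run.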
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int d, const int remain) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    int idx_n = i / d;
    int idx_remain = i % remain;
    logit_grad[i] *= loss_grad[idx_n * remain + idx_remain];
  }
}

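// Soft-label gradient: logit_grad holds the forward softmax output, so the
// per-element gradient is loss_grad * (softmax - labels).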
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int n,
                                               const int d, const int remain) {
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int idx_n = ids / d;
    int idx_remain = ids % remain;
    int idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}

}  // namespace

static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float exp_on_device(float x) {
  return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
  return exp(x);
}
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 x) {
  return math::TolerableValue<platform::float16>()(::Eigen::numext::log(x));
}
static __device__ __forceinline__ float log_on_device(float x) {
  return math::TolerableValue<float>()(logf(x));
}
static __device__ __forceinline__ double log_on_device(double x) {
  return math::TolerableValue<double>()(log(x));
}

/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss. **/
/*
  Suppose x is `logits` and y is `labels`. The equations are as follows:
  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{j}{x_i_j}
  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, we can share memory among max_i, logDiffMaxSum_i and
cross\_entropy_i.
In this way, the 3 steps should be changed to:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result softmax'_i_j =
x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j
- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i
*/
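/*
A small worked example of the three steps (values rounded): take one row
x = [1, 2, 3] with soft label y = [0, 0, 1].
  Step 1: max = 3
  Step 2: softmax' = x - max = [-2, -1, 0],
          logDiffMaxSum = log(e^-2 + e^-1 + e^0) = log(1.5032) ~= 0.4076
  Step 3: tmp = softmax' - logDiffMaxSum ~= [-2.4076, -1.4076, -0.4076],
          softmax = e^tmp ~= [0.0900, 0.2447, 0.6652],
          cross_entropy = -sum(y * tmp) ~= 0.4076
*/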

// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template <typename T, int BlockDim>
using BlockReduce =
    cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;

template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

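// Shared indexing scheme of the three row-reduction kernels below: the data
// is viewed as [n, axis_dim, remain] with d = axis_dim * remain, one block
// per (idx_n, idx_remain) pair. For example, with n = 2, axis_dim = 4 and
// remain = 3 (d = 12), block 5 handles idx_n = 1, idx_remain = 2, and thread
// t reduces over elements 1 * 12 + t * 3 + 2, 1 * 12 + (t + BlockDim) * 3 + 2,
// and so on.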
// Make sure that BlockDim <= axis_dim
// This kernel is used to calculate the max element of each row
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits_data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // blockDim = n * remain, split blockIdx to idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  int step = BlockDim * remain;
  T cur_max = logits_data[beg_idx];
  beg_idx += step;
  while (beg_idx < end_idx) {
    if (cur_max < logits_data[beg_idx]) {
      cur_max = logits_data[beg_idx];
    }
    beg_idx += step;
  }

  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());

  if (threadIdx.x == 0) {
    max_data[blockIdx.x] =
        cur_max < static_cast<T>(-64) ? static_cast<T>(-64) : cur_max;
  }
}

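// Computes softmax'_i_j = x_i_j - max_i in-place in `softmax` and writes
// logDiffMaxSum_i into max_data. When CalculateLogSoftmax is true it also
// subtracts logDiffMaxSum_i so that `softmax` ends up holding log-softmax,
// and max_data is reset to 0 (reused as the loss buffer by the hard-label
// path, where the loss is filled in later by the functors below).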
// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax, int d,
                                                 int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // blockDim = n * remain, split blockIdx to idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  auto block_max = max_data[blockIdx.x];
  int step = BlockDim * remain;

  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  auto idx = beg_idx + step;
  while (idx < end_idx) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
    idx += step;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);

  if (!CalculateLogSoftmax) return;
  __syncthreads();
  diff_max_sum = max_data[blockIdx.x];
  softmax[beg_idx] -= diff_max_sum;
  beg_idx += step;
  while (beg_idx < end_idx) {
    softmax[beg_idx] -= diff_max_sum;
    beg_idx += step;
  }
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}

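// Third step of the fused soft-label path: `softmax` holds x - max and
// loss_data holds logDiffMaxSum. Each element is shifted by logDiffMaxSum and
// exponentiated to give the final softmax, while -labels * log_softmax is
// block-reduced into loss_data (overwriting the shared logDiffMaxSum value).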
// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d,
    int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax, labels data view as [n, axis_dim, remain]
  // loss_data view as [n, 1, remain]
  // blockDim = n * remain, split blockIdx to idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  // log_diff_max_sum shares memory with loss
  auto block_log_diff_max_sum = loss_data[blockIdx.x];
  auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
  softmax[beg_idx] = exp_on_device(tmp);
  auto loss = -labels_data[beg_idx] * tmp;
  int step = BlockDim * remain;
  beg_idx += step;
  while (beg_idx < end_idx) {
    tmp = softmax[beg_idx] - block_log_diff_max_sum;
    softmax[beg_idx] = exp_on_device(tmp);
    loss -= (labels_data[beg_idx] * tmp);
    beg_idx += step;
  }

  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}

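// Element-wise functor for the hard-label path: on entry log_softmax_ holds
// log-softmax values; every element is exponentiated back to softmax, and at
// the label position the loss -log_softmax is recorded.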
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int d, int axis_dim)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis != labels_[idx_lbl]) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
};

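// Same as above, but a position is also treated as a non-label position when
// it matches ignore_idx, so the loss of an ignored sample stays at its
// initialized value of 0.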
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
  int ignore_idx_;
};

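// Degenerate case axis_dim == 1: softmax over a single class is identically 1
// and the corresponding loss is 0 (set via cudaMemsetAsync by the callers).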
template <typename T>
static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int n) {
  auto idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < n) out[idx] = static_cast<T>(1);
}

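// Launch configuration: one block per (n, remain) slice (grid_dim = n * d /
// axis_dim) and block_dim equal to the largest power of two not exceeding
// axis_dim, capped at 512, so that BlockDim <= axis_dim holds for the
// row-reduction kernels.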
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d,
    int axis_dim, int ignore_idx) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;
  auto stream = ctx.stream();

#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)  \
  case BlockDim: {                                                         \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, d, axis_dim);                              \
    RowReductionForDiffMaxSum<T, BlockDim,                                 \
                              true><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, softmax_data, d, axis_dim);                \
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d);  \
    if (ignore_idx >= 0 && ignore_idx < axis_dim) {                        \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(   \
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
    } else {                                                               \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                \
          labels_data, loss_data, softmax_data, d, axis_dim));             \
    }                                                                      \
  } break

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    case 1:
      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) /
                                                kMaxBlockDim,
                                            kMaxBlockDim, 0, stream>>>(
          softmax_data, grid_dim);
      cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream);
      break;
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

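// Fused soft-label forward path: launches the three row-reduction kernels in
// sequence on the same stream, with loss_data doubling as scratch space for
// max_i and logDiffMaxSum_i before receiving the final cross entropy.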
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                               const T* labels_data,
                                               T* softmax_data, T* loss_data,
                                               int n, int d, int axis_dim,
                                               cudaStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;

#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                 \
  case BlockDim:                                                               \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(        \
        logits_data, loss_data, d, axis_dim);                                  \
    RowReductionForDiffMaxSum<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
        logits_data, loss_data, softmax_data, d, axis_dim);                    \
    RowReductionForSoftmaxAndCrossEntropy<                                     \
        T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(                       \
        logits_data, labels_data, loss_data, softmax_data, d, axis_dim);       \
    break

  switch (block_dim) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    case 1:
      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) /
                                                kMaxBlockDim,
                                            kMaxBlockDim, 0, stream>>>(
          softmax_data, n);
      cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream);
      break;
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }

#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

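    // Flatten logits to the 3-D view [n, axis_dim, remain]: n is the product
    // of the dimensions before `axis`, axis_dim the size of `axis`, and
    // d = axis_dim * remain the product of the dimensions from `axis` on.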
    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());

    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    auto soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      auto* logits_data = logits->data<T>();
      auto* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim,
          context.cuda_device_context().stream());
    } else {
      if (!context.Attr<bool>("numeric_stable_mode")) {
        // The cuDNN kernel only supports 2-D tensors and performs softmax on
        // the last dim
        Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
        logits_2d.ShareDataWith(*logits).Resize({n, d});
        softmax_2d.ShareDataWith(*softmax).Resize({n, d});
        labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
        loss_2d.ShareDataWith(*loss).Resize({n, 1});
        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                       &logits_2d, &softmax_2d);
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            false, ignore_index, axis_dim);
      } else {
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
            context.cuda_device_context(), logits_data, labels_data, loss_data,
            softmax_data, n, d, axis_dim, ignore_index);
      }
    }
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->data<T>();

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int n = SizeToAxis(axis, logit_grad->dims());
    const int d = SizeFromAxis(axis, logit_grad->dims());
    const int remain = d / axis_dim;

    int block = 512;
    auto stream = context.cuda_device_context().stream();
    auto ignore_index = context.Attr<int>("ignore_index");
    if (context.Attr<bool>("soft_label")) {
      int grid = (n * d + block - 1) / block;
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
    } else {
      int grid = (n * remain + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
      int num = n * d;
      grid = (num + block - 1) / block;
      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                           d, remain);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);