/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

namespace {
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int n, const int d, const int remain,
                                 const int ignore_index) {
  CUDA_KERNEL_LOOP(index, n * remain) {
    int idx_n = index / remain;
    int idx_remain = index % remain;
    int idx = idx_n * d + labels[index] * remain + idx_remain;
    logit_grad[idx] -=
        ignore_index == labels[index] ? static_cast<T>(0.) : static_cast<T>(1.);
  }
}

template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int d, const int remain) {
  CUDA_KERNEL_LOOP(index, num) {
    int idx_n = index / d;
    int idx_remain = index % remain;
    logit_grad[index] *= loss_grad[idx_n * remain + idx_remain];
  }
}

template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int n,
                                               const int d, const int remain) {
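  // For soft labels, logit_grad holds the softmax output on entry; the
  // gradient is d(loss)/d(logit) = loss_grad * (softmax - label), element-wise.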
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int idx_n = ids / d;
    int idx_remain = ids % remain;
    int idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}

}  // namespace

static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float exp_on_device(float x) {
  return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
  return exp(x);
}
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 x) {
  return math::TolerableValue<platform::float16>()(::Eigen::numext::log(x));
}
static __device__ __forceinline__ float log_on_device(float x) {
  return math::TolerableValue<float>()(logf(x));
}
static __device__ __forceinline__ double log_on_device(double x) {
  return math::TolerableValue<double>()(log(x));
}

/** In the following code, three CUDA kernels are implemented to calculate
 * softmax and loss **/
/*
  Suppose x is `logits` and y is `labels`; the equations are as follows:
  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{j}{x_i_j}
  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, max_i, logDiffMaxSum_i and cross\_entropy_i can share the same
buffer. In this way, the 3 steps become:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
softmax'_i_j = x_i_j - max_i, and a row-wise operation to calculate
logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j =
softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
cross\_entropy_i
*/
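// A small worked example of the three steps above (illustrative numbers only,
// not used by the kernels): for one row x = [1.0, 2.0, 3.0] with one-hot
// label y = [0, 0, 1]:
//   Step 1: max = 3.0
//   Step 2: softmax' = x - max = [-2.0, -1.0, 0.0]
//           logDiffMaxSum = log(e^{-2} + e^{-1} + e^{0}) ~= log(1.5032) ~= 0.4076
//   Step 3: tmp = softmax' - logDiffMaxSum ~= [-2.4076, -1.4076, -0.4076]
//           softmax = e^{tmp} ~= [0.0900, 0.2447, 0.6652]
//           cross_entropy = \sum_{j}(-y_j * tmp_j) ~= 0.4076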

// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template <typename T, int BlockDim>
using BlockReduce =
    cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;

template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

// Make sure that BlockDim <= axis_dim
// This kernel is used to calculate the max element of each row
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits_data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // blockDim = n * remain, split blockIdx to idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;
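  // For illustration (assumed values): with d = 12 and axis_dim = 4, remain
  // is 3; block 5 handles idx_n = 1, idx_remain = 2, so its thread 0 starts
  // at beg_idx = 1 * 12 + 0 * 3 + 2 = 14 and strides by BlockDim * remain.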

  int step = BlockDim * remain;
  T cur_max = logits_data[beg_idx];
  beg_idx += step;
  while (beg_idx < end_idx) {
    if (cur_max < logits_data[beg_idx]) {
      cur_max = logits_data[beg_idx];
    }
    beg_idx += step;
  }

  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());

  if (threadIdx.x == 0) max_data[blockIdx.x] = cur_max;
}

// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax, int d,
                                                 int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // blockDim = n * remain, split blockIdx to idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  auto block_max = max_data[blockIdx.x];
  int step = BlockDim * remain;

  // In the numerically stable mode of softmax_with_loss, the loss is computed
  // with tmp_i_j = x_i_j - max_i - logDiffMaxSum_i instead of
  // log(exp(x_i_j - max_i) / DiffMaxSum_i), so log(0) cannot occur.
  // The softmax is computed as softmax_i_j = e^{tmp_i_j}, whose values lie in
  // [0.0, 1.0] and directly represent probabilities, so there is no need to
  // clip shift_softmax.
  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  auto idx = beg_idx + step;
  while (idx < end_idx) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
    idx += step;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);

  if (!CalculateLogSoftmax) return;
  __syncthreads();
  diff_max_sum = max_data[blockIdx.x];
  softmax[beg_idx] -= diff_max_sum;
  beg_idx += step;
  while (beg_idx < end_idx) {
    softmax[beg_idx] -= diff_max_sum;
    beg_idx += step;
  }

  // Note(zhiqiu): since different threads may use max_data[blockIdx.x] to
  // calculate diff_max_sum, __syncthreads() is needed here.
  __syncthreads();
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}

// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d,
    int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax, labels data view as [n, axis_dim, remain]
  // loss_data view as [n, 1, remain]
  // blockDim = n * remain, split blockIdx to idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  // log_diff_max_sum shares memory with loss
  auto block_log_diff_max_sum = loss_data[blockIdx.x];
  auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
  softmax[beg_idx] = exp_on_device(tmp);
  auto loss = -labels_data[beg_idx] * tmp;
  int step = BlockDim * remain;
  beg_idx += step;
  while (beg_idx < end_idx) {
    tmp = softmax[beg_idx] - block_log_diff_max_sum;
    softmax[beg_idx] = exp_on_device(tmp);
    loss -= (labels_data[beg_idx] * tmp);
    beg_idx += step;
  }

  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}

template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int d, int axis_dim)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis != labels_[idx_lbl]) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
};

template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
  int ignore_idx_;
};

template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d,
    int axis_dim, int ignore_idx) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;
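  // block_dim is axis_dim rounded down to a power of two and capped at
  // kMaxBlockDim, e.g. axis_dim = 100 gives block_dim = 64; grid_dim equals
  // n * remain, i.e. one block per (row, remain-offset) slice.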
  auto stream = ctx.stream();

#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)  \
  case BlockDim: {                                                         \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, d, axis_dim);                              \
    RowReductionForDiffMaxSum<T, BlockDim,                                 \
                              true><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, softmax_data, d, axis_dim);                \
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d);  \
    if (ignore_idx >= 0 && ignore_idx < axis_dim) {                        \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(   \
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
    } else {                                                               \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                \
          labels_data, loss_data, softmax_data, d, axis_dim));             \
    }                                                                      \
  } break

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                               const T* labels_data,
                                               T* softmax_data, T* loss_data,
                                               int n, int d, int axis_dim,
                                               cudaStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;

#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                 \
  case BlockDim:                                                               \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(        \
        logits_data, loss_data, d, axis_dim);                                  \
    RowReductionForDiffMaxSum<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
        logits_data, loss_data, softmax_data, d, axis_dim);                    \
    RowReductionForSoftmaxAndCrossEntropy<                                     \
        T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(                       \
        logits_data, labels_data, loss_data, softmax_data, d, axis_dim);       \
    break

  switch (block_dim) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }

#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());
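    // For illustration (assumed shape): logits of shape [6, 7, 8] with
    // axis = 1 give axis_dim = 7, n = 6 and d = 7 * 8 = 56.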

    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    if (axis_dim == 1) {
      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      return;
    }

    auto soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      auto* logits_data = logits->data<T>();
      auto* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim,
          context.cuda_device_context().stream());
    } else {
      if (!context.Attr<bool>("numeric_stable_mode")) {
        // The cuDNN kernel only supports 2-D tensors and performs softmax on
        // the last dimension.
        Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
        logits_2d.ShareDataWith(*logits).Resize({n, d});
        softmax_2d.ShareDataWith(*softmax).Resize({n, d});
        labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
        loss_2d.ShareDataWith(*loss).Resize({n, 1});
        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                       &logits_2d, &softmax_2d);
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            false, ignore_index, axis_dim);
      } else {
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
            context.cuda_device_context(), logits_data, labels_data, loss_data,
460
            softmax_data, n, d, axis_dim, ignore_index);
S
sneaxiy 已提交
461
      }
S
sneaxiy 已提交
462
    }
C
caoying03 已提交
463 464 465 466
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->data<T>();

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int n = SizeToAxis(axis, logit_grad->dims());
    const int d = SizeFromAxis(axis, logit_grad->dims());
    const int remain = d / axis_dim;
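    // logit_grad is viewed as [n, axis_dim, remain] and the incoming loss
    // gradient as [n, remain], matching the forward kernels above.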

    int block = 512;
    auto stream = context.cuda_device_context().stream();
    auto ignore_index = context.Attr<int>("ignore_index");
    if (context.Attr<bool>("soft_label")) {
      int grid = (n * d + block - 1) / block;
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
    } else {
      int grid = (n * remain + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
      int num = n * d;
      grid = (num + block - 1) / block;
      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                           d, remain);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);