/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

namespace {
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int n, const int d, const int remain,
                                 const int ignore_index) {
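  // Hard-label backward: logit_grad already holds the forward softmax, so
  // subtracting 1 at each sample's label position yields
  // (softmax - one_hot(label)); rows whose label equals ignore_index are left
  // unchanged. The multiplication by the upstream loss gradient happens later
  // in the Scale kernel.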
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n * remain;
       i += blockDim.x * gridDim.x) {
    int idx_n = i / remain;
    int idx_remain = i % remain;
    int idx = idx_n * d + labels[i] * remain + idx_remain;
    logit_grad[idx] -=
        ignore_index == labels[i] ? static_cast<T>(0.) : static_cast<T>(1.);
  }
}

template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int d, const int remain) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    int idx_n = i / d;
    int idx_remain = i % remain;
    logit_grad[i] *= loss_grad[idx_n * remain + idx_remain];
  }
}

template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int n,
                                               const int d, const int remain) {
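  // Soft-label backward: dL/dlogits_ij = loss_grad_i * (softmax_ij - labels_ij),
  // where logit_grad initially holds the forward softmax and labels holds the
  // soft targets.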
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int idx_n = ids / d;
    int idx_remain = ids % remain;
    int idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}

}  // namespace

static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float exp_on_device(float x) {
  return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
  return exp(x);
}
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 x) {
  return math::TolerableValue<platform::float16>()(::Eigen::numext::log(x));
}
static __device__ __forceinline__ float log_on_device(float x) {
  return math::TolerableValue<float>()(logf(x));
}
static __device__ __forceinline__ double log_on_device(double x) {
  return math::TolerableValue<double>()(log(x));
}

/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss. **/
/*
  Supposing x is `logits` and y is `labels`, the equations are as
follows:
  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{j}{x_i_j}
  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, we can share memory among max_i, logDiffMaxSum_i and
cross\_entropy_i.
In this way, the 3 steps should be changed to:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result softmax'_i_j =
x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j
- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i
*/
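/*
  An illustrative example for one row with x_i = [1, 2, 3] and a hard one-hot
  label y_i = [0, 0, 1]:
    max_i           = 3
    softmax'_i      = x_i - max_i = [-2, -1, 0]
    logDiffMaxSum_i = log(e^{-2} + e^{-1} + e^{0}) ~= 0.4076
    tmp_i           = softmax'_i - logDiffMaxSum_i ~= [-2.4076, -1.4076, -0.4076]
    softmax_i       = e^{tmp_i} ~= [0.0900, 0.2447, 0.6652]
    cross_entropy_i = \sum_{j}(-y_i_j * tmp_i_j) ~= 0.4076
*/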

// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template <typename T, int BlockDim>
using BlockReduce =
    cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;

template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

// Make sure that BlockDim <= axis_dim
// This kernel is used to calculate the max element of each row
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int d, int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits_data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // gridDim = n * remain, split blockIdx into idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;
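  // Illustrative example: with d = 8, axis_dim = 4, remain = 2 and
  // BlockDim = 4, block 0 handles (idx_n = 0, idx_remain = 0); thread t starts
  // at beg_idx = t * remain = 2 * t and strides by BlockDim * remain = 8, so
  // the 4 threads together cover the 4 axis entries of that slice.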

  int step = BlockDim * remain;
  T cur_max = logits_data[beg_idx];
  beg_idx += step;
  while (beg_idx < end_idx) {
    if (cur_max < logits_data[beg_idx]) {
      cur_max = logits_data[beg_idx];
    }
    beg_idx += step;
  }

  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());

  if (threadIdx.x == 0) {
    max_data[blockIdx.x] =
        cur_max < static_cast<T>(-64) ? static_cast<T>(-64) : cur_max;
  }
}

// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax, int d,
                                                 int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax data view as [n, axis_dim, remain]
  // max_data view as [n, 1, remain]
  // gridDim = n * remain, split blockIdx into idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  auto block_max = max_data[blockIdx.x];
  int step = BlockDim * remain;

  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  auto idx = beg_idx + step;
  while (idx < end_idx) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
    idx += step;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);

  if (!CalculateLogSoftmax) return;
  __syncthreads();
  diff_max_sum = max_data[blockIdx.x];
  softmax[beg_idx] -= diff_max_sum;
  beg_idx += step;
  while (beg_idx < end_idx) {
    softmax[beg_idx] -= diff_max_sum;
    beg_idx += step;
  }
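  // Reset the per-row buffer (which aliases loss_data in the hard-label path)
  // to zero; presumably this guarantees that rows whose label equals the
  // ignore index keep a loss of 0, since the functor applied afterwards only
  // writes the loss at the (non-ignored) label position.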
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}

// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d,
    int axis_dim) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // logits, softmax, labels data view as [n, axis_dim, remain]
  // loss_data view as [n, 1, remain]
  // gridDim = n * remain, split blockIdx into idx_n and idx_remain
  int remain = d / axis_dim;
  int idx_n = blockIdx.x / remain;
  int idx_remain = blockIdx.x % remain;
  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
  int end_idx = (idx_n + 1) * d;

  // log_diff_max_sum shares memory with loss
  auto block_log_diff_max_sum = loss_data[blockIdx.x];
  auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
  softmax[beg_idx] = exp_on_device(tmp);
  auto loss = -labels_data[beg_idx] * tmp;
  int step = BlockDim * remain;
  beg_idx += step;
  while (beg_idx < end_idx) {
    tmp = softmax[beg_idx] - block_log_diff_max_sum;
    softmax[beg_idx] = exp_on_device(tmp);
    loss -= (labels_data[beg_idx] * tmp);
    beg_idx += step;
  }

  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}

template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
                                          T* log_softmax, int d, int axis_dim)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis != labels_[idx_lbl]) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
};

template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int d, int axis_dim,
                                                       int ignore_idx)
      : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int idx) const {
    // logits view as [n, axis_dim, remain], where d = axis_dim * remain
    int remain = d_ / axis_dim_;
    int idx_n = idx / d_;
    int idx_axis = (idx % d_) / remain;
    int idx_remain = idx % remain;
    // labels, loss view as [n, remain]
    int idx_lbl = idx_n * remain + idx_remain;
    if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
      log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
    } else {
      auto softmax = log_softmax_[idx];
      log_softmax_[idx] = exp_on_device(softmax);
      loss_[idx_lbl] = -softmax;
    }
  }

 private:
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int d_;
  int axis_dim_;
  int ignore_idx_;
};

template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d,
    int axis_dim, int ignore_idx) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;
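  // block_dim is the largest power of two <= axis_dim, capped at 512
  // (e.g. axis_dim = 100 -> 64, axis_dim = 1000 -> 512). grid_dim equals
  // n * remain, i.e. one block per (row, remain) slice reduced over axis_dim.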
  auto stream = ctx.stream();

#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)  \
  case BlockDim: {                                                         \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, d, axis_dim);                              \
    RowReductionForDiffMaxSum<T, BlockDim,                                 \
                              true><<<grid_dim, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, softmax_data, d, axis_dim);                \
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d);  \
    if (ignore_idx >= 0 && ignore_idx < axis_dim) {                        \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(   \
          labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
    } else {                                                               \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                \
          labels_data, loss_data, softmax_data, d, axis_dim));             \
    }                                                                      \
  } break

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                               const T* labels_data,
                                               T* softmax_data, T* loss_data,
                                               int n, int d, int axis_dim,
                                               cudaStream_t stream) {
  constexpr int kMaxBlockDim = 512;
  int block_dim = axis_dim >= kMaxBlockDim
                      ? kMaxBlockDim
                      : (1 << static_cast<int>(std::log2(axis_dim)));
  int grid_dim = n * d / axis_dim;

#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                 \
  case BlockDim:                                                               \
    RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(        \
        logits_data, loss_data, d, axis_dim);                                  \
    RowReductionForDiffMaxSum<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
        logits_data, loss_data, softmax_data, d, axis_dim);                    \
    RowReductionForSoftmaxAndCrossEntropy<                                     \
        T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>(                       \
        logits_data, labels_data, loss_data, softmax_data, d, axis_dim);       \
    break

  switch (block_dim) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }

#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logits->dims()[axis];

    if (axis_dim == 1) {
      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
      set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
      set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
      return;
    }

    const int n = SizeToAxis(axis, logits->dims());
    const int d = SizeFromAxis(axis, logits->dims());
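    // e.g. for logits of shape [N, C, H, W] with axis == 1: n = N,
    // d = C * H * W and axis_dim = C, so the kernels view the data as
    // [n, axis_dim, remain] with remain = H * W.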

    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    auto soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");

    if (soft_label) {
      auto* logits_data = logits->data<T>();
      auto* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim,
          context.cuda_device_context().stream());
    } else {
      if (!context.Attr<bool>("numeric_stable_mode")) {
        // The cuDNN kernel only supports 2-D tensors and performs softmax on the last dim
        Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
        logits_2d.ShareDataWith(*logits).Resize({n, d});
        softmax_2d.ShareDataWith(*softmax).Resize({n, d});
        labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
        loss_2d.ShareDataWith(*loss).Resize({n, 1});
        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                       &logits_2d, &softmax_2d);
        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
            false, ignore_index, axis_dim);
      } else {
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
            context.cuda_device_context(), logits_data, labels_data, loss_data,
            softmax_data, n, d, axis_dim, ignore_index);
      }
    }
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
    if (logit_grad != softmax) {
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    T* logit_grad_data = logit_grad->data<T>();

    const int rank = logit_grad->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = logit_grad->dims()[axis];

    const int n = SizeToAxis(axis, logit_grad->dims());
    const int d = SizeFromAxis(axis, logit_grad->dims());
    const int remain = d / axis_dim;
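    // Backward pass: logit_grad starts as a copy of the forward softmax. For
    // soft labels a single kernel computes loss_grad * (softmax - labels); for
    // hard labels CrossEntropyGrad subtracts 1 at each label position and
    // Scale then multiplies by the upstream loss gradient.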

    int block = 512;
    auto stream = context.cuda_device_context().stream();
    auto ignore_index = context.Attr<int>("ignore_index");
    if (context.Attr<bool>("soft_label")) {
      int grid = (n * d + block - 1) / block;
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
    } else {
      int grid = (n * remain + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
      int num = n * d;
      grid = (num + block - 1) / block;
      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                           d, remain);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);