softmax_with_cross_entropy_op.cu 20.1 KB
Newer Older
S
sneaxiy 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
S
sneaxiy 已提交
11 12
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
Y
Yi Wang 已提交
13
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
14
#include "paddle/fluid/platform/for_range.h"
15

C
caoying03 已提交
16 17 18 19 20
namespace paddle {
namespace operators {

// Shorthand for the framework tensor type used throughout this file.
using Tensor = framework::Tensor;

21
namespace {
C
caoying03 已提交
22
// Hard-label part of the softmax-with-cross-entropy gradient.
//
// On entry `logit_grad` holds the forward softmax as `batch_size` rows of
// `class_num` entries. For every row i whose label is a valid class, 1 is
// subtracted at column labels[i], turning the row into (softmax - onehot).
//
// Rows whose label equals `ignore_index` are zeroed entirely: the forward
// pass leaves the loss of such rows at 0 (presumably also in
// math::CrossEntropyFunctor — confirm), so their gradient is 0. This also
// keeps a negative ignore_index (e.g. -100) from being used as a column
// offset. Labels outside [0, class_num) are skipped so they cannot cause
// out-of-bounds writes.
//
// Each loop iteration handles one row; the loop strides by the total number
// of launched threads, so any 1-D grid covers all rows.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int batch_size, const int class_num,
                                 const int ignore_index) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
       i += blockDim.x * gridDim.x) {
    const int64_t label = labels[i];
    // 64-bit offset: batch_size * class_num may exceed INT_MAX.
    T* row = logit_grad + static_cast<int64_t>(i) * class_num;
    if (label == ignore_index) {
      // Serial per-row loop; only taken for ignored rows.
      for (int j = 0; j < class_num; ++j) row[j] = static_cast<T>(0);
    } else if (label >= 0 && label < class_num) {
      row[label] -= static_cast<T>(1.);
    }
  }
}
Y
Yu Yang 已提交
33

34 35 36 37 38 39
// Multiplies each of the `num` entries of `logit_grad` (rows of `class_num`
// entries) by the upstream gradient of its row, loss_grad[row].
// The loop advances by the total number of launched threads, so any 1-D
// launch configuration covers all `num` entries.
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int class_num) {
  const int stride = blockDim.x * gridDim.x;
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < num) {
    const int row = pos / class_num;
    logit_grad[pos] *= loss_grad[row];
    pos += stride;
  }
}

// Soft-label gradient, element-wise:
//   logit_grad[r][c] = loss_grad[r] * (logit_grad[r][c] - labels[r][c])
// where `logit_grad` holds the softmax on entry and both matrices have
// `class_num` columns. One thread per element with no striding: the launch
// must provide at least batch_size * class_num threads; extra threads exit.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels,
                                               const int batch_size,
                                               const int class_num) {
  const int elem = blockIdx.x * blockDim.x + threadIdx.x;
  if (elem >= batch_size * class_num) return;
  const T diff = logit_grad[elem] - labels[elem];
  logit_grad[elem] = loss_grad[elem / class_num] * diff;
}
S
sneaxiy 已提交
55

56
}  // namespace
C
caoying03 已提交
57

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
// Device-side exponential, overloaded for each element type the kernels in
// this file are instantiated with (double, float, float16).
static __device__ __forceinline__ double exp_on_device(double value) {
  return exp(value);
}
static __device__ __forceinline__ float exp_on_device(float value) {
  // Single-precision call, so float inputs are not promoted to double.
  return expf(value);
}
static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 value) {
  // Half precision goes through Eigen's numext::exp.
  return ::Eigen::numext::exp(value);
}
// Device-side natural logarithm, overloaded for float16, float and double.
// Every result is passed through math::TolerableValue<T> (from
// math/cross_entropy.h), presumably to map +/-inf to finite values — confirm
// there.
static __device__ __forceinline__ platform::float16 log_on_device(
    platform::float16 value) {
  auto raw = ::Eigen::numext::log(value);
  return math::TolerableValue<platform::float16>()(raw);
}
static __device__ __forceinline__ float log_on_device(float value) {
  auto raw = logf(value);
  return math::TolerableValue<float>()(raw);
}
static __device__ __forceinline__ double log_on_device(double value) {
  auto raw = log(value);
  return math::TolerableValue<double>()(raw);
}

/** The following code implements 3 CUDA kernels that together calculate the
 * softmax and the cross-entropy loss. **/
/*
  Suppose x is `logits` and y is `labels`. The equations are as follows:
  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{k}e^{x_i_k}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{k}e^{x_i_k-max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{k}e^{x_i_k - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)
  softmax_i_j = e^{tmp_i_j}
where:
  max_i = \max_{j}{x_i_j}
  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, max_i, logDiffMaxSum_i and cross\_entropy_i share the same
per-row buffer.
With this sharing, the 3 steps become:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
softmax'_i_j = x_i_j - max_i, and a row-wise operation to calculate
logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate
tmp_i_j = softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
cross\_entropy_i
*/

// Block-wide reduction built on cub::BlockReduce. cub provides three
// block-reduction algorithms:
//   BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
//   BLOCK_REDUCE_RAKING
//   BLOCK_REDUCE_WARP_REDUCTIONS (cub's default)
// The default is named explicitly below; this is the same type as
// cub::BlockReduce<T, BlockDim>.
template <typename T, int BlockDim>
using BlockReduce =
    cub::BlockReduce<T, BlockDim, cub::BLOCK_REDUCE_WARP_REDUCTIONS>;

// Scratch space needed by BlockReduce<T, BlockDim>. Kernels declare it
// `__shared__` and hand it to the BlockReduce constructor.
template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

// Step 1: computes max_i = max_j logits[i][j] and stores it in max_data[i].
//
// Launch layout: one block per row (blockIdx.x selects the row) with BlockDim
// threads; `logits_data` has `feature_size` columns per row.
// Requires BlockDim <= feature_size: every thread unconditionally seeds its
// partial maximum with the first element it owns.
//
// The exact row maximum is stored. Subtracting it makes the largest shifted
// logit exactly 0, so the later sum of exponentials is >= 1 and its log is
// finite. Clamping the maximum (e.g. to -64) would break that for rows whose
// logits are all very negative, where every exp() then underflows to 0.
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                          int feature_size) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // 64-bit offsets: rows * feature_size may exceed 32 bits.
  int64_t beg_idx =
      static_cast<int64_t>(feature_size) * blockIdx.x + threadIdx.x;
  const int64_t end_idx =
      static_cast<int64_t>(feature_size) * (blockIdx.x + 1);

  T cur_max = logits_data[beg_idx];
  for (beg_idx += BlockDim; beg_idx < end_idx; beg_idx += BlockDim) {
    if (cur_max < logits_data[beg_idx]) {
      cur_max = logits_data[beg_idx];
    }
  }

  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());

  // Only thread 0 holds the valid block-wide result.
  if (threadIdx.x == 0) max_data[blockIdx.x] = cur_max;
}

// Step 2: for every row i writes softmax'_i_j = x_i_j - max_i into `softmax`
// and computes logDiffMaxSum_i = log(sum_j exp(softmax'_i_j)).
//
// Launch layout: one block per row, BlockDim threads; rows have
// `feature_size` columns. Requires BlockDim <= feature_size. On entry
// max_data[i] holds max_i.
//
// CalculateLogSoftmax == false: max_data[i] is overwritten with
//   logDiffMaxSum_i and `softmax` keeps x_i_j - max_i.
// CalculateLogSoftmax == true: logDiffMaxSum_i is also subtracted, so
//   `softmax` holds the log-softmax, and max_data[i] is reset to 0 at the end.
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
                                                 T* max_data, T* softmax,
                                                 int feature_size) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  // 64-bit offsets: rows * feature_size may exceed 32 bits.
  const int64_t beg_idx =
      static_cast<int64_t>(feature_size) * blockIdx.x + threadIdx.x;
  const int64_t end_idx =
      static_cast<int64_t>(feature_size) * (blockIdx.x + 1);

  auto block_max = max_data[blockIdx.x];

  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = exp_on_device(softmax[beg_idx]);
  for (int64_t idx = beg_idx + BlockDim; idx < end_idx; idx += BlockDim) {
    softmax[idx] = logits_data[idx] - block_max;
    diff_max_sum += exp_on_device(softmax[idx]);
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  // Only thread 0 holds the valid reduction result.
  if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);

  // Compile-time constant: the whole block takes the same path, so the
  // barriers below are never divergent.
  if (!CalculateLogSoftmax) return;

  // Publish thread 0's store to the rest of the block.
  __syncthreads();
  const T log_diff_max_sum = max_data[blockIdx.x];
  for (int64_t idx = beg_idx; idx < end_idx; idx += BlockDim) {
    softmax[idx] -= log_diff_max_sum;
  }

  // Every thread must have read max_data[blockIdx.x] before thread 0 resets
  // it; otherwise a slower warp could read 0 and produce a wrong log-softmax.
  __syncthreads();
  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}

// Step 3 (soft labels). On entry `softmax` holds softmax'_i_j = x_i_j - max_i
// and loss_data[i] holds logDiffMaxSum_i. For each row i this computes
//   tmp_i_j         = softmax'_i_j - logDiffMaxSum_i
//   softmax_i_j     = exp(tmp_i_j)                  -> written to `softmax`
//   cross_entropy_i = -sum_j labels_i_j * tmp_i_j   -> written to loss_data[i]
// Launch layout: one block per row, BlockDim threads; requires
// BlockDim <= feature_size. `logits_data` is not read.
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
    int feature_size) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  auto pos = feature_size * blockIdx.x + threadIdx.x;
  const auto row_end = feature_size * (blockIdx.x + 1);

  // loss_data doubles as storage for logDiffMaxSum_i; thread 0 overwrites it
  // with the row loss only after the block-wide reduction below.
  auto log_sum = loss_data[blockIdx.x];

  // The first owned element seeds this thread's partial loss.
  auto log_prob = softmax[pos] - log_sum;
  softmax[pos] = exp_on_device(log_prob);
  auto partial_loss = -labels_data[pos] * log_prob;

  for (pos += BlockDim; pos < row_end; pos += BlockDim) {
    log_prob = softmax[pos] - log_sum;
    softmax[pos] = exp_on_device(log_prob);
    partial_loss -= (labels_data[pos] * log_prob);
  }

  const auto row_loss =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(partial_loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = row_loss;
}

// Per-element final step of the numeric-stable hard-label path (no ignore
// index). `log_softmax` holds log-softmax values in rows of `feature_size`
// entries. Each call converts one element in place to its softmax value;
// for the element in column labels[row] it also writes
// loss[row] = -log_softmax. Rows whose label matches no column leave loss
// untouched. `logits` is stored but never read.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctor(const T* logits,
                                          const int64_t* labels, T* loss,
                                          T* log_softmax, int feature_size)
      : logits_(logits),
        labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        feature_size_(feature_size) {}

  __device__ void operator()(int idx) const {
    const auto row = idx / feature_size_;
    const auto col = idx % feature_size_;
    const auto log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    if (col == labels_[row]) {
      loss_[row] = -log_prob;
    }
  }

 private:
  const T* logits_;
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int feature_size_;
};

// Same as the hard-label functor above, but a row whose label equals
// `ignore_idx` never has its loss written. Each call converts one
// log-softmax element in place to its softmax value, and writes
// loss[row] = -log_softmax only at the labelled, non-ignored column.
// `logits` is stored but never read.
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 public:
  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const T* logits,
                                                       const int64_t* labels,
                                                       T* loss, T* log_softmax,
                                                       int feature_size,
                                                       int ignore_idx)
      : logits_(logits),
        labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        feature_size_(feature_size),
        ignore_idx_(ignore_idx) {}

  __device__ void operator()(int idx) const {
    const auto row = idx / feature_size_;
    const auto col = idx % feature_size_;
    const auto log_prob = log_softmax_[idx];
    log_softmax_[idx] = exp_on_device(log_prob);
    const bool is_scored_label = col == labels_[row] && col != ignore_idx_;
    if (is_scored_label) {
      loss_[row] = -log_prob;
    }
  }

 private:
  const T* logits_;
  const int64_t* labels_;
  T* loss_;
  T* log_softmax_;
  int feature_size_;
  int ignore_idx_;
};

// Writes 1 into out[0 .. batch_size), one element per thread. Used when each
// row has a single class, so its softmax is exactly 1. There is no striding:
// the grid must cover batch_size threads; extra threads exit immediately.
template <typename T>
static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out,
                                                           int batch_size) {
  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= batch_size) return;
  out[tid] = static_cast<T>(1);
}

S
sneaxiy 已提交
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
// Numeric-stable forward pass for hard (int64) labels.
//
// `logits_data` is [batch_size, feature_size] row-major and `labels_data`
// holds one class index per row. On return (asynchronously, on ctx.stream())
// `softmax_data` holds the softmax and loss_data[i] the cross entropy of row
// i. Rows whose label is `ignore_idx` or matches no column get loss 0. When
// feature_size == 1 the softmax is 1 and the loss is 0.
//
// Kernels run with one block per row and a power-of-two block size of at most
// min(feature_size, 512), since the RowReductionFor* kernels require
// BlockDim <= feature_size. loss_data doubles as per-row scratch space.
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
    const platform::CUDADeviceContext& ctx, const T* logits_data,
    const int64_t* labels_data, T* loss_data, T* softmax_data, int batch_size,
    int feature_size, int ignore_idx) {
  // Nothing to compute; also avoids an invalid zero-sized grid launch.
  if (batch_size <= 0) return;
  // There is no valid block size for an empty class axis.
  if (feature_size <= 0) {
    PADDLE_THROW(
        "feature_size must be positive in softmax_with_cross_entropy_op");
  }

  constexpr int kMaxBlockDim = 512;
  // Largest power of two not exceeding min(feature_size, kMaxBlockDim),
  // computed in integer arithmetic rather than through std::log2.
  int block_dim = 1;
  while (block_dim < kMaxBlockDim && block_dim * 2 <= feature_size) {
    block_dim *= 2;
  }
  auto stream = ctx.stream();

#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)    \
  case BlockDim: {                                                           \
    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, feature_size);                               \
    RowReductionForDiffMaxSum<T, BlockDim,                                   \
                              true><<<batch_size, BlockDim, 0, stream>>>(    \
        logits_data, loss_data, softmax_data, feature_size);                 \
    platform::ForRange<platform::CUDADeviceContext> for_range(               \
        ctx, batch_size* feature_size);                                      \
    if (ignore_idx >= 0 && ignore_idx < feature_size) {                      \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(     \
          logits_data, labels_data, loss_data, softmax_data, feature_size,   \
          ignore_idx));                                                      \
    } else {                                                                 \
      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                  \
          logits_data, labels_data, loss_data, softmax_data, feature_size)); \
    }                                                                        \
  } break

  switch (block_dim) {
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    case 1:
      // A single class per row: softmax is 1 and the loss is 0.
      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
                                                kMaxBlockDim,
                                            kMaxBlockDim, 0, stream>>>(
          softmax_data, batch_size);
      PADDLE_ENFORCE(
          cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream),
          "cudaMemsetAsync failed in softmax_with_cross_entropy_op");
      break;
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

S
sneaxiy 已提交
339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
// Numeric-stable forward pass for soft (floating-point) labels.
//
// `logits_data` and `labels_data` are [batch_size, feature_size] row-major.
// On return (asynchronously, on `stream`) `softmax_data` holds the softmax and
// loss_data[i] the cross entropy of row i; loss_data doubles as the per-row
// scratch buffer for max_i and logDiffMaxSum_i. When feature_size == 1 the
// softmax is 1 and the loss is 0.
//
// Kernels run with one block per row and a power-of-two block size of at most
// min(feature_size, 512), since the RowReductionFor* kernels require
// BlockDim <= feature_size.
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                               const T* labels_data,
                                               T* softmax_data, T* loss_data,
                                               int batch_size, int feature_size,
                                               cudaStream_t stream) {
  // Nothing to compute; also avoids an invalid zero-sized grid launch.
  if (batch_size <= 0) return;
  // There is no valid block size for an empty class axis.
  if (feature_size <= 0) {
    PADDLE_THROW(
        "feature_size must be positive in softmax_with_cross_entropy_op");
  }

  constexpr int kMaxBlockDim = 512;
  // Largest power of two not exceeding min(feature_size, kMaxBlockDim),
  // computed in integer arithmetic rather than through std::log2.
  int block_dim = 1;
  while (block_dim < kMaxBlockDim && block_dim * 2 <= feature_size) {
    block_dim *= 2;
  }

#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                \
  case BlockDim:                                                              \
    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(     \
        logits_data, loss_data, feature_size);                                \
    RowReductionForDiffMaxSum<T,                                              \
                              BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
        logits_data, loss_data, softmax_data, feature_size);                  \
    RowReductionForSoftmaxAndCrossEntropy<                                    \
        T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(                    \
        logits_data, labels_data, loss_data, softmax_data, feature_size);     \
    break

  switch (block_dim) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    case 1:
      // A single class per row: softmax is 1 and the loss is 0.
      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
                                                kMaxBlockDim,
                                            kMaxBlockDim, 0, stream>>>(
          softmax_data, batch_size);
      PADDLE_ENFORCE(
          cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream),
          "cudaMemsetAsync failed in softmax_with_cross_entropy_op");
      break;
    default:
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }

#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

C
caoying03 已提交
387
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of softmax_with_cross_entropy on the GPU.
  //
  // Reads "Logits" and "Label" and writes "Softmax" and "Loss". The last axis
  // of Logits is the class axis; all leading axes are folded into the batch.
  // Dispatch:
  //  * soft_label: fused kernels with T-typed labels;
  //  * hard labels with numeric_stable_mode: fused kernels with int64 labels
  //    that honor ignore_index;
  //  * hard labels otherwise: cuDNN softmax followed by
  //    math::CrossEntropyFunctor on 2-D views.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    const bool soft_label = context.Attr<bool>("soft_label");
    const int ignore_index = context.Attr<int>("ignore_index");
    const int rank = logits->dims().size();

    // Folds every leading axis of Logits into the row count; the last axis
    // gives the number of columns (classes).
    auto flatten_to_rows = [logits, rank](int* rows, int* cols) {
      *rows = 1;
      for (int axis = 0; axis + 1 < rank; ++axis) {
        *rows *= logits->dims()[axis];
      }
      *cols = logits->dims()[rank - 1];
    };

    if (soft_label) {
      int batch_size = 0;
      int feature_size = 0;
      flatten_to_rows(&batch_size, &feature_size);
      const T* logits_data = logits->data<T>();
      const T* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, batch_size,
          feature_size, context.cuda_device_context().stream());
      return;
    }

    if (context.Attr<bool>("numeric_stable_mode")) {
      int batch_size = 0;
      int feature_size = 0;
      flatten_to_rows(&batch_size, &feature_size);
      const T* logits_data = logits->data<T>();
      const int64_t* labels_data = labels->data<int64_t>();
      HardLabelSoftmaxWithCrossEntropy<T>(
          context.cuda_device_context(), logits_data, labels_data, loss_data,
          softmax_data, batch_size, feature_size, ignore_index);
      return;
    }

    // Unfused path operating on [batch, classes] views of every tensor.
    Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
    Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
    Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
    Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);

    math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), &logits_2d,
                                   &softmax_2d);
    math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
        context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
        false, ignore_index);
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward pass of softmax_with_cross_entropy on the GPU.
  //
  // Logits@GRAD starts as a copy of the forward output "Softmax" and is then
  // updated in place on the CUDA device context's stream:
  //  * soft_label: SoftCrossEntropyGradientKernel computes
  //    Loss@GRAD[row] * (softmax - label) element-wise;
  //  * hard labels: CrossEntropyGrad adjusts each row's label column, then
  //    Scale multiplies every element by Loss@GRAD of its row.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    framework::TensorCopy(*context.Input<Tensor>("Softmax"), context.GetPlace(),
                          context.device_context(), logit_grad);
    T* logit_grad_data = logit_grad->data<T>();

    // View Logits@GRAD as [batch_size, class_num]: every leading axis is
    // folded into the batch.
    const int rank = logit_grad->dims().size();
    int batch_size = 1;
    for (int axis = 0; axis + 1 < rank; ++axis) {
      batch_size *= logit_grad->dims()[axis];
    }
    const int class_num = logit_grad->dims()[rank - 1];
    const int num_elements = batch_size * class_num;

    constexpr int kThreadsPerBlock = 512;
    // Number of blocks needed to give `work` items one thread each.
    auto blocks_for = [](int work, int threads) {
      return (work + threads - 1) / threads;
    };
    auto stream = context.cuda_device_context().stream();
    const int ignore_index = context.Attr<int>("ignore_index");

    if (!context.Attr<bool>("soft_label")) {
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<blocks_for(batch_size, kThreadsPerBlock),
                            kThreadsPerBlock, 0, stream>>>(
          logit_grad_data, label_data, batch_size, class_num, ignore_index);
      Scale<T><<<blocks_for(num_elements, kThreadsPerBlock), kThreadsPerBlock,
                 0, stream>>>(logit_grad_data, loss_grad_data, num_elements,
                              class_num);
    } else {
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<
          T><<<blocks_for(num_elements, kThreadsPerBlock), kThreadsPerBlock, 0,
               stream>>>(logit_grad_data, loss_grad_data, label_data,
                         batch_size, class_num);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Forward kernel, registered for float, float16 and double element types.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
// Backward kernel, registered for the same element types as the forward one.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);