softmax_with_cross_entropy_op.cu 4.8 KB
Newer Older
C
caoying03 已提交
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#define EIGEN_USE_GPU
C
caoying03 已提交
16

Y
Yi Wang 已提交
17
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
18

C
caoying03 已提交
19 20 21 22 23
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

24
namespace {
C
caoying03 已提交
25
// Gradient of softmax + hard-label cross entropy with respect to the logits.
//
// `logit_grad` is a row-major [batch_size, class_num] buffer (row = i /
// class_num). In this file the caller aliases it to the forward Softmax output,
// so on entry it holds the softmax probabilities. Each element is rewritten as
//
//   logit_grad[r][c] = (logit_grad[r][c] - (c == labels[r] ? 1 : 0))
//                      * loss_grad[r]
//
// Every element is owned by exactly one thread, and the label subtraction and
// the scaling are fused. This is deliberate: __syncthreads() only orders
// threads inside a single block, so it cannot separate a "subtract 1 at the
// label" phase from a "scale the row" phase when the label element lives in a
// different block.
//
// Launch: any 1-D grid/block. The loop strides by the total thread count, so
// all batch_size * class_num elements are covered even if the grid is smaller.
// Indices are 64-bit so large batch_size * class_num does not overflow `int`.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
                                 const int64_t* labels, const int batch_size,
                                 const int class_num) {
  const int64_t num = static_cast<int64_t>(batch_size) * class_num;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num; i += stride) {
    const int64_t row = i / class_num;
    const int64_t col = i - row * class_num;
    const int64_t label = labels[row];
    // Validate the label of the row this element belongs to.
    PADDLE_ASSERT(label >= 0 && label < class_num);
    T grad = logit_grad[i];
    if (col == label) {
      grad -= static_cast<T>(1.);
    }
    logit_grad[i] = grad * loss_grad[row];
  }
}

// Gradient of softmax + soft-label cross entropy with respect to the logits.
//
// All three element buffers are indexed as a row-major
// [batch_size, class_num] matrix; `loss_grad` has one entry per row.
// On entry `logit_grad` is expected to hold the softmax probabilities (the grad
// kernel in this file aliases it to the forward Softmax output). Each element
// becomes
//
//   logit_grad[r][c] = loss_grad[r] * (logit_grad[r][c] - labels[r][c])
//
// Launch: one thread per element; threads past batch_size * class_num exit.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels,
                                               const int batch_size,
                                               const int class_num) {
  const int elem = threadIdx.x + blockDim.x * blockIdx.x;
  if (elem >= batch_size * class_num) {
    return;
  }
  const T diff = logit_grad[elem] - labels[elem];
  logit_grad[elem] = loss_grad[elem / class_num] * diff;
}
56
}  // namespace
C
caoying03 已提交
57 58

template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward pass on the GPU.
  //
  // Reads inputs "Logits" and "Label" and writes two outputs:
  //   "Softmax" <- math::SoftmaxFunctor applied to Logits
  //   "Loss"    <- math::CrossEntropyFunctor applied to (Softmax, Label),
  // with the bool attribute "soft_label" forwarded to the cross-entropy
  // functor. Enforces that the kernel runs on a GPU place.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits_in = context.Input<Tensor>("Logits");
    const Tensor* label_in = context.Input<Tensor>("Label");
    Tensor* softmax_out = context.Output<Tensor>("Softmax");
    Tensor* loss_out = context.Output<Tensor>("Loss");

    // Allocate both outputs on this kernel's place before computing.
    const auto place = context.GetPlace();
    softmax_out->mutable_data<T>(place);
    loss_out->mutable_data<T>(place);

    const auto& dev_ctx = context.cuda_device_context();
    math::SoftmaxFunctor<platform::CUDADeviceContext, T>()(dev_ctx, logits_in,
                                                           softmax_out);
    const bool soft_label = context.Attr<bool>("soft_label");
    math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
        dev_ctx, loss_out, softmax_out, label_in, soft_label);
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward pass on the GPU: computes Logits@GRAD from Loss@GRAD, Label and
  // the forward "Softmax" output.
  //
  // Logits@GRAD is viewed as a [batch_size, class_num] matrix (dims()[0] and
  // dims()[1]). The "soft_label" attribute selects the kernel:
  //   true : Label is a T matrix of the same shape
  //          -> SoftCrossEntropyGradientKernel
  //   false: Label is an int64 class index per row
  //          -> CrossEntropyGrad
  // Both kernels are launched on the op's CUDA stream with one thread per
  // element and are not synchronized here.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    // NOTE(review): ShareDataWith makes Logits@GRAD reuse the Softmax input's
    // buffer, so the kernels below overwrite Softmax in place. Confirm that no
    // other consumer reads Softmax after this op runs.
    logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
    T* logit_grad_data = logit_grad->data<T>();

    const int batch_size = logit_grad->dims()[0];
    const int class_num = logit_grad->dims()[1];
    // Element count in 64 bits so the grid-size computation cannot overflow.
    const int64_t num_elements = static_cast<int64_t>(batch_size) * class_num;
    // Nothing to compute; also avoids a zero-sized grid, which is an invalid
    // launch configuration.
    if (num_elements <= 0) {
      return;
    }

    const int block = 512;
    const int grid = static_cast<int>((num_elements + block - 1) / block);
    auto stream =
        context.template device_context<platform::CUDADeviceContext>()
            .stream();

    if (context.Attr<bool>("soft_label")) {
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
    } else {
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Q
QI JUN 已提交
121 122 123 124 125 126
// Register the forward and backward GPU kernels for float and double.
REGISTER_OP_CUDA_KERNEL(softmax_with_cross_entropy,
                        ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
                        ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(softmax_with_cross_entropy_grad,
                        ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
                        ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);