/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#define EIGEN_USE_GPU
C
caoying03 已提交
16

17
#include "paddle/operators/softmax_with_cross_entropy_op.h"
18

C
caoying03 已提交
19 20 21 22 23
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

24
namespace {
C
caoying03 已提交
25
// Computes, in place, the gradient of the hard-label cross-entropy loss with
// respect to the logits.
//
// out_grad : batch_size x class_num row-major buffer. On entry it holds the
//            forward softmax values (the caller in this file makes it share
//            storage with the "Softmax" tensor); on exit it holds the logit
//            gradient.
// in_grad  : upstream gradient of the loss, indexed by sample
//            (in_grad[i] is read for row i).
// labels   : class index per sample (labels[i] is read for row i); each must
//            lie in [0, class_num).
//
// Result:  out_grad[i][j] = in_grad[i] * (out_grad[i][j] - (j == labels[i])).
//
// Launch layout: 1-D grid with at least batch_size * class_num threads, one
// thread per element. No shared memory is used.
template <typename T>
__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
                                 const int* labels, const int batch_size,
                                 const int class_num) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= batch_size * class_num) return;

  int sample_idx = tid / class_num;
  int class_idx = tid - sample_idx * class_num;

  int label = labels[sample_idx];
  PADDLE_ASSERT(label >= 0 && label < class_num);

  // Every thread touches only its own element. Letting a thread modify an
  // element owned by another thread would race, because __syncthreads() only
  // orders threads that belong to the same block.
  T one_hot = (class_idx == label) ? static_cast<T>(1) : static_cast<T>(0);

  // Chain rule: the upstream gradient scales the whole (softmax - one_hot)
  // term, not just the softmax part.
  out_grad[tid] = in_grad[sample_idx] * (out_grad[tid] - one_hot);
}

// Computes, in place, the gradient of the soft-label cross-entropy loss with
// respect to the logits.
//
// logit_grad : batch_size x class_num row-major buffer. On entry it holds the
//              forward softmax values (the caller in this file makes it share
//              storage with the "Softmax" tensor); on exit it holds the logit
//              gradient.
// loss_grad  : upstream gradient of the loss, indexed by row.
// labels     : per-element soft labels, same layout as logit_grad.
//
// Result:  logit_grad[i][j] = loss_grad[i] * (logit_grad[i][j] - labels[i][j]).
//
// Launch layout: 1-D grid with at least batch_size * class_num threads, one
// thread per element. No shared memory is used.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels,
                                               const int batch_size,
                                               const int class_num) {
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < batch_size * class_num) {
    int row_ids = ids / class_num;
    // Chain rule: the upstream gradient multiplies the full
    // (softmax - label) difference.
    logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]);
  }
}
53
}  // namespace
C
caoying03 已提交
54 55 56 57 58 59 60 61

// GPU forward kernel of the softmax_with_cross_entropy operator.
//
// Reads the "Logits" and "Label" inputs, writes the "Softmax" and "Loss"
// outputs. The "softLabel" attribute is forwarded to the cross-entropy
// functor.
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");

    // Outputs: fetch each one and allocate its storage on the current place.
    Tensor* softmax_out = context.Output<Tensor>("Softmax");
    softmax_out->mutable_data<T>(context.GetPlace());
    Tensor* loss_out = context.Output<Tensor>("Loss");
    loss_out->mutable_data<T>(context.GetPlace());

    // Inputs.
    const Tensor* logits_in = context.Input<Tensor>("Logits");
    const Tensor* label_in = context.Input<Tensor>("Label");

    // Step 1: softmax over the logits.
    math::SoftmaxFunctor<platform::GPUPlace, T> softmax_fn;
    softmax_fn(context, logits_in, softmax_out);

    // Step 2: cross entropy between the softmax output and the labels.
    const bool use_soft_label = context.Attr<bool>("softLabel");
    math::CrossEntropyFunctor<platform::GPUPlace, T> cross_entropy_fn;
    cross_entropy_fn(context, loss_out, softmax_out, label_in, use_soft_label);
  }
};

// GPU backward kernel of the softmax_with_cross_entropy operator.
//
// Reads "Label", the gradient of "Loss" and the forward "Softmax" output, and
// produces the gradient of "Logits". The "softLabel" attribute selects which
// device kernel is launched.
template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* label_tensor = context.Input<Tensor>("Label");
    const Tensor* loss_grad_tensor =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    const T* d_loss = loss_grad_tensor->data<T>();

    // The gradient tensor is set up via ShareDataWith on the forward Softmax
    // tensor; both device kernels then read the existing contents of this
    // buffer and overwrite them.
    // NOTE(review): presumably this aliases (and therefore clobbers) the
    // Softmax storage - confirm against Tensor::ShareDataWith.
    Tensor* d_logits = context.Output<Tensor>(framework::GradVarName("Logits"));
    d_logits->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
    T* d_logits_data = d_logits->data<T>();

    const int batch_size = d_logits->dims()[0];
    const int class_num = d_logits->dims()[1];

    // One thread per element, rounded up to whole blocks.
    const int threads_per_block = 512;
    const int num_blocks =
        (batch_size * class_num + threads_per_block - 1) / threads_per_block;

    // Both launches are issued on the stream of the op's device context.
    const auto& cuda_ctx = reinterpret_cast<const platform::CUDADeviceContext&>(
        context.device_context());

    if (!context.Attr<bool>("softLabel")) {
      // Hard labels: integer class indices.
      const int* hard_labels = label_tensor->data<int>();
      CrossEntropyGrad<T><<<num_blocks, threads_per_block, 0,
                            cuda_ctx.stream()>>>(
          d_logits_data, d_loss, hard_labels, batch_size, class_num);
    } else {
      // Soft labels: label data has the same element type as the logits.
      const T* soft_labels = label_tensor->data<T>();
      SoftCrossEntropyGradientKernel<T><<<num_blocks, threads_per_block, 0,
                                          cuda_ctx.stream()>>>(
          d_logits_data, d_loss, soft_labels, batch_size, class_num);
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Register the float instantiations of the forward and backward kernels as
// the GPU kernels of the softmax_with_cross_entropy and
// softmax_with_cross_entropy_grad operators.
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad,
                       ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>);