cross_entropy_op.cu 4.2 KB
Newer Older
L
liaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/operators/cross_entropy_op.h"

#include <cstdint>
16 17 18 19

namespace paddle {
namespace operators {

20
namespace {
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35

// Flat offset of element (row, col) in a row-major matrix with D columns.
// Computed in 64 bits so that row * D cannot overflow a 32-bit int.
__host__ __device__ inline int64_t RowMajorOffset(int64_t row, int D,
                                                  int col) {
  return row * static_cast<int64_t>(D) + col;
}

// Gradient of the hard-label cross-entropy with respect to X.
//
// X and dX are indexed as row-major [N, D] matrices; dY and label hold one
// value per row. For every row i in [0, N):
//   dX[i, label[i]] = -dY[i] / X[i, label[i]]
// Only the labelled column of each row is written, so the caller is expected
// to zero dX beforehand.
//
// Launch: 1-D grid of 1-D blocks. Rows are visited with a grid-stride loop,
// so any grid size covers all N rows. No shared memory is used.
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int* label, const int N,
                                           const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  // The counter and stride are 64-bit: with an int counter, `i += stride`
  // may overflow when N is close to INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < N; i += stride) {
    // NOTE(review): label[i] is assumed to lie in [0, D); it is not checked.
    const int64_t idx = RowMajorOffset(i, D, label[i]);
    dX[idx] = -dY[i] / X[idx];
  }
}

// Row of a flat element index in a row-major matrix with D columns.
// Uses 64-bit arithmetic so flat indices beyond INT_MAX are handled.
__host__ __device__ inline int64_t FlatIndexToRow(int64_t flat_index, int D) {
  return flat_index / D;
}

// Gradient of the soft-label cross-entropy with respect to X.
//
// X, label and dX are indexed as row-major [N, D] matrices; dY holds one
// value per row. Every element of dX is written:
//   dX[i, j] = -label[i, j] * dY[i] / X[i, j]
//
// Launch: 1-D grid of 1-D blocks. Elements are visited with a grid-stride
// loop, so any grid size covers all N * D elements. No shared memory is used.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  // N * D is formed in 64 bits because it can exceed INT_MAX.
  const int64_t total = static_cast<int64_t>(N) * D;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t ids =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       ids < total; ids += stride) {
    dX[ids] = -label[ids] * dY[FlatIndexToRow(ids, D)] / X[ids];
  }
}
45
}  // namespace
46 47

template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward cross-entropy on the GPU.
  //
  // Reads inputs "X" and "Label", allocates output "Y" on the current place,
  // and delegates the computation to math::CrossEntropyFunctor. The
  // "soft_label" attribute is forwarded unchanged to the functor.
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    auto* input = ctx.Input<Tensor>("X");
    auto* labels = ctx.Input<Tensor>("Label");
    auto* output = ctx.Output<Tensor>("Y");
    output->mutable_data<T>(ctx.GetPlace());

    const bool soft_label = ctx.Attr<bool>("soft_label");
    math::CrossEntropyFunctor<platform::GPUPlace, T> cross_entropy;
    cross_entropy(ctx.device_context(), output, input, labels, soft_label);
  }
};

template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward cross-entropy on the GPU: writes the gradient of "X".
  //
  // X is read as a [batch_size, class_num] matrix (dims()[0], dims()[1]).
  //  - soft_label == true : Label holds T values laid out like X, and every
  //    element of dX is written by SoftCrossEntropyGradientKernel.
  //  - soft_label == false: Label holds one int class index per row. dX is
  //    zeroed first, and then CrossEntropyGradientKernel writes only the
  //    labelled column of each row.
  // Kernels are launched asynchronously on the CUDA device context's stream.
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* label = ctx.Input<Tensor>("Label");
    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    // Allocate dX once and keep the device pointer.
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    const T* dy_data = dy->data<T>();
    const T* x_data = x->data<T>();

    // NOTE(review): assumes X is 2-D; only dims()[0] and dims()[1] are read.
    const int batch_size = x->dims()[0];
    const int class_num = x->dims()[1];

    // The element count is formed in 64 bits to avoid int overflow. A launch
    // with zero blocks is an invalid configuration, so an empty input
    // returns before any kernel is launched.
    const int64_t numel = static_cast<int64_t>(batch_size) * class_num;
    if (numel == 0) return;

    const int block = 512;
    // device_context() is a platform::DeviceContext; downcast it to the CUDA
    // context to obtain its stream (static_cast is the correct downcast).
    auto stream = static_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();

    if (ctx.Attr<bool>("soft_label")) {
      const T* label_data = label->data<T>();
      // One thread per element of X.
      const int grid = static_cast<int>((numel + block - 1) / block);
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, batch_size, class_num);
    } else {
      // The kernel writes one element per row, so clear the rest first.
      math::SetConstant<platform::GPUPlace, T> functor;
      functor(ctx.device_context(), dx, 0);
      const int* label_data = label->data<int>();
      // One thread per row.
      const int grid = (batch_size + block - 1) / block;
      CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, batch_size, class_num);
    }
    // NOTE(review): launch errors are not checked here; consider surfacing
    // cudaGetLastError() through the project's enforce mechanism.
  }
};

}  // namespace operators
}  // namespace paddle
Q
Qiao Longfei 已提交
109

D
dongzhihong 已提交
110
namespace ops = paddle::operators;

// Register the float GPU kernels for the forward and backward ops.
REGISTER_OP_GPU_KERNEL(
    cross_entropy,
    ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(
    cross_entropy_grad,
    ops::CrossEntropyGradientOpCUDAKernel<float>);