// File: cross_entropy_op.cu (3.9 KB).
// Lines 1-14 last committed by liaogang.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/operators/cross_entropy_op.h"
16 17 18 19

namespace paddle {
namespace operators {

20
namespace {
21 22 23

// Gradient of hard-label cross entropy with respect to X.
//
// X and dX are dense row-major N x D buffers; dY holds one value per row;
// label holds one class index per row. For each row i in [0, N) this writes
//   dX[i * D + label[i]] = -dY[i] / X[i * D + label[i]]
// and touches no other element of dX, so callers must zero dX beforehand.
//
// Launch: any 1-D grid / 1-D block. A grid-stride loop visits every row, so
// correctness does not depend on the grid being large enough.
// NOTE(review): label[i] is assumed to lie in [0, D); it is not range-checked.
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int64_t* label, const int N,
                                           const int D) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < N; i += stride) {
    // Form the flat offset in 64 bits: i * D can exceed INT_MAX for large
    // batch/class counts even though N and D individually fit in int.
    const int64_t idx = i * D + label[i];
    dX[idx] = -dY[i] / X[idx];
  }
}

// Gradient of soft-label cross entropy with respect to X.
//
// X, label and dX are dense row-major N x D buffers; dY holds one value per
// row. Every element of dX is written:
//   dX[i * D + j] = -label[i * D + j] * dY[i] / X[i * D + j]
//
// Launch: any 1-D grid / 1-D block. A grid-stride loop visits all N * D
// elements, so correctness does not depend on the grid covering the whole
// buffer (e.g. if the grid had to be capped at the device limit).
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  // 64-bit arithmetic: N * D may exceed INT_MAX even when N and D fit in int.
  const int64_t total = static_cast<int64_t>(N) * D;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t ids =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       ids < total; ids += stride) {
    // total > 0 inside the loop implies D > 0, so this division is safe.
    const int64_t row_ids = ids / D;
    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
  }
}
43
}  // namespace
44 45

// GPU forward kernel for the cross_entropy operator.
//
// Reads inputs "X" and "Label", allocates output "Y" on the current (GPU)
// place, and delegates the computation to math::CrossEntropyFunctor. The
// boolean attribute "soft_label" is forwarded unchanged to the functor.
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());
    PADDLE_ENFORCE(on_gpu, "This kernel only runs on GPU device.");

    const Tensor* input = ctx.Input<Tensor>("X");
    const Tensor* labels = ctx.Input<Tensor>("Label");
    Tensor* output = ctx.Output<Tensor>("Y");
    output->mutable_data<T>(ctx.GetPlace());

    math::CrossEntropyFunctor<platform::GPUPlace, T> compute_cross_entropy;
    compute_cross_entropy(ctx.device_context(), output, input, labels,
                          ctx.Attr<bool>("soft_label"));
  }
};

// GPU backward kernel for the cross_entropy operator.
//
// Inputs: "X", "Label" and the gradient of "Y". Output: the gradient of "X".
// X is indexed as [batch_size, class_num] via dims()[0] and dims()[1].
// NOTE(review): X is assumed to be 2-D — confirm against shape inference.
//
// The boolean attribute "soft_label" selects the path:
//   * soft labels: "Label" is read as T data, and every element of dX is
//     written by SoftCrossEntropyGradientKernel.
//   * hard labels: "Label" is read as one int64 class index per row; dX is
//     zero-filled, then CrossEntropyGradientKernel writes one entry per row.
template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* label = ctx.Input<Tensor>("Label");
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    // Allocate dX once and keep the device pointer.
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());

    const T* dy_data =
        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
    const T* x_data = x->data<T>();

    const int64_t batch_size = x->dims()[0];
    const int64_t class_num = x->dims()[1];
    // Nothing to differentiate for an empty input. Returning here also avoids
    // launching with a 0-sized grid (an invalid launch configuration), and, on
    // the hard-label path, indexing into a zero-width dX.
    if (batch_size == 0 || class_num == 0) return;

    const int block = 512;
    auto stream = ctx.cuda_device_context().stream();

    if (ctx.Attr<bool>("soft_label")) {
      auto* label_data = label->data<T>();
      // One thread per element of the batch_size x class_num gradient.
      const int grid =
          static_cast<int>((batch_size * class_num + block - 1) / block);
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, batch_size, class_num);
    } else {
      // Only the labelled entry of each row receives a non-zero gradient.
      math::SetConstant<platform::GPUPlace, T> functor;
      functor(ctx.device_context(), dx, 0);
      auto* label_data = label->data<int64_t>();
      // One thread per row.
      const int grid = static_cast<int>((batch_size + block - 1) / block);
      CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, batch_size, class_num);
    }
  }
};

}  // namespace operators
}  // namespace paddle
Q
Qiao Longfei 已提交
102

D
dongzhihong 已提交
103
namespace ops = paddle::operators;

// Register the float and double GPU kernels for the forward op.
REGISTER_OP_GPU_KERNEL(
    cross_entropy,
    ops::CrossEntropyOpCUDAKernel<float>,
    ops::CrossEntropyOpCUDAKernel<double>);

// Register the float and double GPU kernels for the gradient op.
REGISTER_OP_GPU_KERNEL(
    cross_entropy_grad,
    ops::CrossEntropyGradientOpCUDAKernel<float>,
    ops::CrossEntropyGradientOpCUDAKernel<double>);