/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cross_entropy_op.h"
namespace paddle {
namespace operators {
namespace {

// Gradient of hard-label cross entropy with respect to X.
//
// dX and X are addressed as row-major N x D buffers (element (i, j) lives at
// i * D + j); dY and label hold one value per row. For each row i only the
// entry in column label[i] is written:
//   dX[i * D + label[i]] = -dY[i] / X[i * D + label[i]]
// All other entries of dX are left untouched, so dX is expected to be
// initialized (e.g. zero-filled) by the caller before launch.
//
// Rows are visited with a grid-stride loop, so any launch with at least one
// thread covers all N rows. Rows whose label lies outside [0, D) are skipped
// instead of writing outside the row; their gradient keeps its initial value.
// Index arithmetic is done in 64 bits because i * D can exceed INT_MAX even
// when N and D individually fit in an int.
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int64_t* label, const int N,
                                           const int D) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < N; i += stride) {
    const int64_t lbl = label[i];
    // An unchecked label would index memory outside this row.
    if (lbl < 0 || lbl >= D) continue;
    const int64_t idx = i * D + lbl;
    dX[idx] = -dY[i] / X[idx];
  }
}

// Gradient of soft-label cross entropy with respect to X.
//
// dX, X and label are addressed as row-major N x D buffers; dY holds one value
// per row. Every element of dX is written:
//   dX[i * D + j] = -label[i * D + j] * dY[i] / X[i * D + j]
//
// Elements are visited with a grid-stride loop over the N * D flat range, so
// the usual one-thread-per-element launch still works, and so does any smaller
// grid. The flat index and the N * D bound use 64-bit arithmetic to avoid int
// overflow on large inputs. When N * D == 0 nothing is written (and no division
// by D happens).
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  const int64_t total = static_cast<int64_t>(N) * D;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t ids =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       ids < total; ids += stride) {
    const int64_t row_ids = ids / D;
    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
  }
}

}  // namespace
// Forward CUDA kernel of the cross_entropy operator.
//
// Looks up the "X" and "Label" inputs and the "Y" output, allocates Y on the
// execution place, and hands the actual computation to
// math::CrossEntropyFunctor on the CUDA device context. The boolean
// "soft_label" attribute is passed straight through to the functor.
// Enforces that the kernel is executing on a GPU place.
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* in_x = ctx.Input<Tensor>("X");
    const Tensor* in_label = ctx.Input<Tensor>("Label");
    Tensor* out_y = ctx.Output<Tensor>("Y");
    out_y->mutable_data<T>(ctx.GetPlace());

    const bool soft_label = ctx.Attr<bool>("soft_label");
    auto& cuda_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();

    math::CrossEntropyFunctor<platform::CUDADeviceContext, T> cross_entropy;
    cross_entropy(cuda_ctx, out_y, in_x, in_label, soft_label);
  }
};

// Backward CUDA kernel of the cross_entropy operator.
//
// Reads "X", "Label" and the gradient of "Y" (GradVarName("Y")), and writes
// the gradient of "X" (GradVarName("X")). X is treated as a
// batch_size x class_num matrix taken from x->dims()[0] and x->dims()[1].
//  - soft_label == true: Label is read as T data with one weight per element
//    of X; SoftCrossEntropyGradientKernel writes every element of dX.
//  - soft_label == false: Label is read as int64 class indices, one per row;
//    dX is zero-filled first, then CrossEntropyGradientKernel writes the
//    single labelled entry of each row.
// Enforces that the kernel is executing on a GPU place.
template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* label = ctx.Input<Tensor>("Label");
    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    // Allocate the output exactly once and keep the returned pointer.
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    const T* dy_data = dy->data<T>();
    const T* x_data = x->data<T>();

    const int64_t batch_size = x->dims()[0];
    const int64_t class_num = x->dims()[1];
    // Nothing to compute. Also, a zero-sized grid is an invalid launch
    // configuration, and the hard-label kernel would index out of bounds
    // with class_num == 0.
    if (batch_size == 0 || class_num == 0) return;

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto stream = dev_ctx.stream();
    constexpr int block = 512;

    // The kernels take their sizes as int; the casts below make that
    // narrowing explicit (assumes batch_size and class_num fit in int).
    const int n = static_cast<int>(batch_size);
    const int d = static_cast<int>(class_num);

    if (ctx.Attr<bool>("soft_label")) {
      const T* label_data = label->data<T>();
      // One thread per element of X.
      const int grid =
          static_cast<int>((batch_size * class_num + block - 1) / block);
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, n, d);
    } else {
      // Only one entry per row is written by the kernel; everything else
      // must be zero.
      math::SetConstant<platform::CUDADeviceContext, T> set_zero;
      set_zero(dev_ctx, dx, static_cast<T>(0));
      const int64_t* label_data = label->data<int64_t>();
      // One thread per row.
      const int grid = static_cast<int>((batch_size + block - 1) / block);
      CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, n, d);
    }
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

// CUDA kernels for the cross_entropy operator, instantiated for float and
// double.
REGISTER_OP_CUDA_KERNEL(
    cross_entropy,
    ops::CrossEntropyOpCUDAKernel<float>,
    ops::CrossEntropyOpCUDAKernel<double>);

// CUDA kernels for the cross_entropy_grad operator, instantiated for float
// and double.
REGISTER_OP_CUDA_KERNEL(
    cross_entropy_grad,
    ops::CrossEntropyGradientOpCUDAKernel<float>,
    ops::CrossEntropyGradientOpCUDAKernel<double>);