You need to sign in or sign up before continuing.
cross_entropy_op.cu 4.5 KB
Newer Older
L
liaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/operators/cross_entropy_op.h"
16 17 18 19

namespace paddle {
namespace operators {

20
namespace {
C
caoying03 已提交
21
// TODO(qingqing): make zero setting a common function.
22
// Fills the first N elements of X with zero.
//
// Written as a grid-stride loop: each thread begins at its global thread
// index and advances by the total number of launched threads, so the kernel
// covers all N elements for any 1-D launch configuration.
template <typename T>
__global__ void Zero(T* X, const int N) {
  const unsigned int step = blockDim.x * gridDim.x;
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < N) {
    X[pos] = 0.0;
    pos += step;
  }
}

// Cross-entropy gradient with respect to X for integer (hard) labels.
//
// For every row i in [0, N) this writes exactly one element:
//   dX[i * D + label[i]] = -dY[i] / X[i * D + label[i]]
// All other elements of dX are left untouched, so the caller must initialise
// dX (e.g. zero it) before launching this kernel.
//
// Parameters:
//   dX    - output buffer, indexed as a row-major N x D array.
//   dY    - upstream gradient, one value per row.
//   X     - forward input, indexed as a row-major N x D array.
//   label - one column index per row; not range-checked here.
//   N, D  - number of rows and number of columns.
//
// Uses a grid-stride loop over rows, so any 1-D launch configuration works.
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int* label, const int N,
                                           const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(i, N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    // Widen before multiplying: i * D evaluated in 32-bit int overflows once
    // the tensor holds more than INT_MAX elements.
    const long long idx = static_cast<long long>(i) * D + label[i];
    dX[idx] = -dY[i] / X[idx];
  }
}

// Cross-entropy gradient with respect to X for soft (per-class weight) labels.
//
// For every flat element index e in [0, N * D):
//   dX[e] = -label[e] * dY[e / D] / X[e]
// where e / D is the row the element belongs to.
//
// Parameters:
//   dX    - output buffer, indexed as a row-major N x D array.
//   dY    - upstream gradient, one value per row.
//   X     - forward input, same layout as dX.
//   label - per-element label weights, same layout as dX.
//   N, D  - number of rows and number of columns.
//
// Uses a grid-stride loop like the other kernels in this file, so every
// element is written even when the launch has fewer than N * D threads.
// The element count and index are 64-bit so N * D cannot overflow int.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  const long long total = static_cast<long long>(N) * D;
  const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
  for (long long ids =
           static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
       ids < total; ids += stride) {
    const long long row_ids = ids / D;
    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
  }
}
53
}  // namespace
54 55 56

// GPU kernel class for the forward cross-entropy operator.
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
 public:
  // Fetches inputs "X" and "Label", allocates output "Y" on the context's
  // place, and delegates the computation to math::CrossEntropyFunctor,
  // forwarding the "softLabel" attribute.
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* input = ctx.Input<Tensor>("X");
    const Tensor* target = ctx.Input<Tensor>("Label");
    Tensor* output = ctx.Output<Tensor>("Y");

    // Allocate the output buffer before the functor writes into it.
    output->mutable_data<T>(ctx.GetPlace());

    const bool soft_label = ctx.Attr<bool>("softLabel");
    math::CrossEntropyFunctor<platform::GPUPlace, T>()(ctx, output, input,
                                                       target, soft_label);
  }
};

// GPU kernel class for the backward (gradient) cross-entropy operator.
template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 public:
  // Computes the gradient with respect to "X" from inputs "X", "Label" and
  // the gradient of "Y", writing into the gradient-of-"X" output.
  //
  // X's first two dims are read as (batch_size, class_num). When the
  // "softLabel" attribute is true, Label is read as T data and every element
  // of dX is written by SoftCrossEntropyGradientKernel. Otherwise Label is
  // read as int data, dX is zeroed first, and CrossEntropyGradientKernel
  // writes one element per row.
  //
  // All kernels are enqueued on the device context's stream; no
  // synchronisation is performed here.
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* label = ctx.Input<Tensor>("Label");
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    // Allocate the output once and keep the pointer.
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    const T* dy_data =
        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
    const T* x_data = x->data<T>();

    int batch_size = x->dims()[0];
    int class_num = x->dims()[1];

    // Nothing to compute for an empty tensor; a launch with zero blocks would
    // be an invalid configuration.
    if (batch_size <= 0 || class_num <= 0) return;

    // Every launch below uses the same stream, so look it up once.
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();

    int block = 512;
    // One thread per element of dX.
    int grid = (batch_size * class_num + block - 1) / block;

    if (ctx.Attr<bool>("softLabel")) {
      auto* label_data = label->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, batch_size, class_num);
    } else {
      // The hard-label kernel writes only the label column of each row, so
      // clear the whole buffer first.
      Zero<T><<<grid, block, 0, stream>>>(dx_data, batch_size * class_num);

      auto* label_data = label->data<int>();
      // One thread per row for the hard-label kernel.
      grid = (batch_size + block - 1) / block;
      CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, batch_size, class_num);
    }
  }
};

}  // namespace operators
}  // namespace paddle
Q
Qiao Longfei 已提交
120

D
dongzhihong 已提交
121
// Short alias for paddle::operators.
namespace ops = paddle::operators;
122 123 124
// Pass the float instantiations of the forward and gradient kernel classes to
// REGISTER_OP_GPU_KERNEL under the names cross_entropy and cross_entropy_grad.
// NOTE(review): only float is listed here; presumably other element types are
// not available on GPU for these ops — confirm against the macro definition.
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
                       ops::CrossEntropyGradientOpCUDAKernel<float>);