cross_entropy_op.cu 6.9 KB
Newer Older
L
liaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/framework/op_registry.h"
16
#include "paddle/operators/cross_entropy_op.h"
17
#include "paddle/platform/assert.h"
18
#include "paddle/platform/hostdevice.h"
19 20 21 22 23 24 25 26 27 28 29 30

namespace paddle {
namespace operators {

// Hard-label forward pass. For every sample i in [0, N) this writes
//   Y[i] = -TolerableValue(log(X[i * D + label[i]]))
// i.e. X is read as N rows of D values and label[i] selects the column of
// row i. TolerableValue<T> is defined outside this file.
// Launch: any 1-D grid; each thread advances by the total thread count until
// it passes N, so all samples are covered regardless of grid size.
// TODO(qingqing): define a CUDA_1D_KERNEL_LOOP macro in a common file.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  const int step = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += step) {
    // Traps on device if the class index is outside [0, D).
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    const int picked = label[i];
    const T prob = X[i * D + picked];
    Y[i] = -TolerableValue<T>()(log(prob));
  }
}

// Sums `val` across the 32 lanes of the calling warp using shuffle-down.
// After the offsets 16, 8, 4, 2, 1, lane 0 holds the total of all 32 lanes;
// the other lanes hold partial sums and should be ignored.
// Precondition: all 32 lanes of the warp are active and call this function,
// because each step reads the value held by lane (laneId + offset).
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
  // CUDA 9+: mask-less __shfl_down is deprecated and unavailable on sm_70+,
  // and independent thread scheduling requires an explicit participant mask.
  const unsigned int kFullWarpMask = 0xffffffffu;
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(kFullWarpMask, val, offset);
  }
#else
  // Pre-CUDA-9 toolchains only provide the implicit-warp-sync variant.
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset);
  }
#endif
  return val;
}

// Soft-label forward pass. Each block handles one row r = blockIdx.x of the
// (batch x class_num) input and writes
//   Y[r] = -sum_j TolerableValue(log(X[r][j])) * label[r][j].
// Launch requirements:
//   - gridDim.x == number of rows (one block per row);
//   - blockDim.x is a power of two (the tree reduction halves the stride);
//   - dynamic shared memory of blockDim.x * sizeof(T) bytes for d_sum.
// Threads whose index is >= class_num contribute 0, so blockDim.x may exceed
// class_num.
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  int tid = threadIdx.x;
  extern __shared__ T d_sum[];

  // Strided per-thread partial sum over columns tid, tid + blockDim.x, ...
  // accumulated in a register rather than shared memory.
  T partial = 0;
  const int row_offset = blockIdx.x * class_num;
  for (int col = tid; col < class_num; col += blockDim.x) {
    const int idx = row_offset + col;
    partial += TolerableValue<T>()(std::log(X[idx])) * label[idx];
  }
  d_sum[tid] = partial;
  __syncthreads();

  // Shared-memory tree reduction until at most 32 partial sums remain.
  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
    __syncthreads();
  }

  // blockDim.x is uniform across the block, so every thread takes the same
  // branch below and the __syncthreads() calls inside it are safe.
  if (blockDim.x >= 32) {
    // d_sum[0..31] holds the remaining partials and warp 0 is fully
    // populated, so finish with a full-warp shuffle reduction.
    if (tid < 32) {
      T val = sum_single_warp<T>(d_sum[tid]);
      if (tid == 0) Y[blockIdx.x] = -val;
    }
  } else {
    // Fewer than 32 threads: a full-warp shuffle would read inactive lanes
    // (undefined values), so keep reducing in shared memory instead.
    for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
      if (tid < stride) d_sum[tid] += d_sum[tid + stride];
      __syncthreads();
    }
    if (tid == 0) Y[blockIdx.x] = -d_sum[0];
  }
}

// TODO(qingqing): make zero setting a common function.
// Writes zero to X[0..N). Works with any 1-D launch: each thread starts at its
// global index and advances by the total number of launched threads.
template <typename T>
__global__ void Zero(T* X, const int N) {
  const int step = blockDim.x * gridDim.x;
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < N) {
    X[pos] = static_cast<T>(0);
    pos += step;
  }
}

// Hard-label backward pass. For each sample i in [0, N) it writes only the
// label column of row i:
//   dX[i * D + label[i]] = -dY[i] / X[i * D + label[i]]
// All other entries of dX are left untouched, so the caller is expected to
// have zeroed dX beforehand.
// Launch: any 1-D grid; the grid-stride loop covers all N samples.
// TODO(qingqing): define a CUDA_1D_KERNEL_LOOP macro in a common file.
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int* label, const int N,
                                           const int D) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    // Same precondition the forward kernel checks; without it an
    // out-of-range label would read/write outside row i of X and dX.
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    int idx = i * D + label[i];
    dX[idx] = -dY[i] / X[idx];
  }
}

// Soft-label backward pass over the N x D input. For every flat index
// ids = r * D + c it writes
//   dX[ids] = -label[ids] * dY[r] / X[ids].
// Launch: any 1-D grid; the grid-stride loop visits all N * D elements even
// when fewer than ceil(N * D / blockDim.x) blocks are launched.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  const int total = N * D;
  for (int ids = blockIdx.x * blockDim.x + threadIdx.x; ids < total;
       ids += blockDim.x * gridDim.x) {
    int row_ids = ids / D;
    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
  }
}

template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
 public:
  // Forward pass of cross_entropy on GPU. X is read as (batch_size x
  // class_num) via x->dims()[0] and x->dims()[1]. Y receives one loss per row:
  //   softLabel: Y[i] = -sum_j TolerableValue(log(X[i][j])) * Label[i][j]
  //   otherwise: Y[i] = -TolerableValue(log(X[i][Label[i]])), Label is int.
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* label = ctx.Input<Tensor>("Label");
    Tensor* y = ctx.Output<Tensor>("Y");

    const T* x_data = x->data<T>();
    T* y_data = y->mutable_data<T>(ctx.GetPlace());

    int batch_size = x->dims()[0];
    int class_num = x->dims()[1];

    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();

    if (ctx.Attr<bool>("softLabel")) {
      const T* label_data = label->data<T>();
      // One block per row. SoftCrossEntropyKernel needs a power-of-two block
      // size: use the largest power of two <= class_num, clamped to
      // [32, 512]. The 32-thread minimum costs nothing (a smaller block still
      // occupies a whole warp) and keeps the final warp-shuffle reduction on
      // a fully populated warp; surplus threads contribute 0 to the sum.
      int block = 32;
      while (block < 512 && block * 2 <= class_num) block *= 2;

      SoftCrossEntropyKernel<T><<<batch_size, block, block * sizeof(T),
                                  stream>>>(y_data, x_data, label_data,
                                            class_num);
    } else {
      const int* label_data = label->data<int>();
      int block = 512;
      int grid = (batch_size + block - 1) / block;
      CrossEntropyKernel<T><<<grid, block, 0, stream>>>(
          y_data, x_data, label_data, batch_size, class_num);
    }
  }
};

template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 public:
  // Backward pass of cross_entropy on GPU. Reads X, Label and the gradient of
  // Y, and writes the gradient of X:
  //   softLabel: dX[i][j] = -Label[i][j] * dY[i] / X[i][j]
  //   otherwise: dX is zeroed, then dX[i][Label[i]] = -dY[i] / X[i][Label[i]]
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* label = ctx.Input<Tensor>("Label");
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const T* dy_data = dy->data<T>();
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    const T* x_data = x->data<T>();

    const int batch_size = x->dims()[0];
    const int class_num = x->dims()[1];
    const int element_count = batch_size * class_num;

    const int kThreads = 512;
    const int element_blocks = (element_count + kThreads - 1) / kThreads;
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();

    if (!ctx.Attr<bool>("softLabel")) {
      // Hard labels write one entry per row, so clear the whole of dX first.
      Zero<T><<<element_blocks, kThreads, 0, stream>>>(dx_data,
                                                        element_count);
      const int* label_data = label->data<int>();
      const int row_blocks = (batch_size + kThreads - 1) / kThreads;
      CrossEntropyGradientKernel<T><<<row_blocks, kThreads, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, batch_size, class_num);
      return;
    }

    // Soft labels write every element of dX, so no zeroing is needed.
    const T* label_data = label->data<T>();
    SoftCrossEntropyGradientKernel<T><<<element_blocks, kThreads, 0,
                                        stream>>>(dx_data, dy_data, x_data,
                                                  label_data, batch_size,
                                                  class_num);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the float GPU implementations of the forward and gradient ops.
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
                       ops::CrossEntropyGradientOpCUDAKernel<float>);