cross_entropy_op.cu 7.9 KB
Newer Older
L
liaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/framework/op_registry.h"
16
#include "paddle/operators/cross_entropy_op.h"
17
#include "paddle/platform/assert.h"
18
#include "paddle/platform/hostdevice.h"
19 20 21 22 23 24 25 26 27 28 29 30

namespace paddle {
namespace operators {

// Hard-label forward kernel.
//
// X is read as N rows of D contiguous values. For every row i in [0, N) this
// writes Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]])).
// TolerableValue comes from cross_entropy_op.h (presumably it clamps
// non-finite values — TODO confirm). Rows are visited with a grid-stride loop,
// so any 1-D launch configuration covers all N rows. Each label is checked to
// lie in [0, D) with PADDLE_ASSERT.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  const unsigned int step = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += step) {
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    const T picked = X[i * D + label[i]];
    Y[i] = -TolerableValue<T>()(log(picked));
  }
}

// Warp-level sum reduction using shuffle-down.
//
// After the call, lane 0 of the calling warp holds the sum of `val` over all
// 32 lanes. The values returned to the other lanes are partial sums and must
// be ignored. All 32 lanes of the warp must be present and call this function
// together, so the caller's block must consist of whole warps.
//
// The mask-less __shfl_down was deprecated in CUDA 9 and is not available for
// compute capability 7.0+. Whenever nvcc provides it, the *_sync variant is
// used with an explicit full-warp mask.
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
    val += __shfl_down_sync(0xffffffffu, val, offset);
#else
    val += __shfl_down(val, offset);
#endif
  }
  return val;
}

// Soft-label forward kernel, used when the class number is <= 512.
//
// Launch layout:
//   - One block per sample: block b reads X/label[b * class_num, (b + 1) * class_num).
//   - Dynamic shared memory of blockDim.x * sizeof(T) bytes.
//
// Each thread accumulates label * TolerableValue(log(X)) over columns tid,
// tid + blockDim.x, ... of its row. The block then reduces the partial sums,
// and thread 0 stores Y[blockIdx.x] = -sum.
//
// Preconditions (not checked here): blockDim.x must be a power of two and at
// least 32. The shared-memory tree stops once 32 partials remain, and the last
// step is a full-warp shuffle in sum_single_warp.
template <typename T>
__global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
                                        const int class_num) {
  extern __shared__ T d_sum[];
  const int tid = threadIdx.x;
  const int row_offset = blockIdx.x * class_num;

  // Per-thread partial sum, kept in a register and published once.
  T partial = 0;
  for (int col = tid; col < class_num; col += blockDim.x) {
    partial += TolerableValue<T>()(std::log(X[row_offset + col])) *
               label[row_offset + col];
  }
  d_sum[tid] = partial;
  __syncthreads();

  // Halve the number of live partials until d_sum[0..31] holds the block total.
  for (unsigned int half = blockDim.x / 2; half >= 32; half /= 2) {
    if (tid < half) d_sum[tid] += d_sum[tid + half];
    __syncthreads();
  }

  // Every thread takes part in the shuffle; only thread 0's result is stored.
  const T block_sum = sum_single_warp<T>(d_sum[tid]);
  if (tid == 0) Y[blockIdx.x] = -block_sum;
}

// Soft-label forward kernel, used when the class number is larger than 512.
//
// Launch layout:
//   - One block per sample: block b reads X/label[b * class_num, (b + 1) * class_num).
//   - blockDim.x must equal BlockSize, because columns are strided by BlockSize
//     and the reduction buffer is a static BlockSize-element shared array.
//
// Each thread accumulates label * TolerableValue(log(X)) over its strided
// columns. The block reduces the partials, and thread 0 stores
// Y[blockIdx.x] = -sum.
//
// Precondition (not checked here): BlockSize must be a power of two and at
// least 32, since the tree reduction stops at 32 partials and finishes with a
// full-warp shuffle.
template <typename T, int BlockSize>
__global__ void SoftCrossEntropyKernel2(T* Y, const T* X, const T* label,
                                        const int class_num) {
  __shared__ T d_sum[BlockSize];
  const int tid = threadIdx.x;
  const int row_offset = blockIdx.x * class_num;

  // Per-thread partial sum, kept in a register and published once.
  T partial = 0;
  for (int col = tid; col < class_num; col += BlockSize) {
    partial += TolerableValue<T>()(std::log(X[row_offset + col])) *
               label[row_offset + col];
  }
  d_sum[tid] = partial;
  __syncthreads();

  // Halve the number of live partials until d_sum[0..31] holds the block total.
  for (unsigned int half = BlockSize / 2; half >= 32; half /= 2) {
    if (tid < half) d_sum[tid] += d_sum[tid + half];
    __syncthreads();
  }

  // Every thread takes part in the shuffle; only thread 0's result is stored.
  const T block_sum = sum_single_warp<T>(d_sum[tid]);
  if (tid == 0) Y[blockIdx.x] = -block_sum;
}

// TODO(qingqing): make zero setting a common function.
// Sets X[0], ..., X[N - 1] to zero. A grid-stride loop is used, so any 1-D
// launch configuration covers the whole range.
template <typename T>
__global__ void zero(T* X, const int N) {
  const unsigned int step = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N; idx += step) {
    X[idx] = static_cast<T>(0);
  }
}

// Hard-label backward kernel.
//
// For each row i in [0, N) of the N x D matrix X, writes
// dX[i * D + label[i]] = -dY[i] / X[i * D + label[i]]. No other entry of dX is
// touched, so the host code zero-fills dX before launching this kernel. Rows
// are visited with a grid-stride loop.
//
// Each label is checked to lie in [0, D), as the forward CrossEntropyKernel
// does. Without this check, an invalid label would silently write outside dX.
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int* label, const int N,
                                           const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    int idx = i * D + label[i];
    dX[idx] = -dY[i] / X[idx];
  }
}

// Soft-label backward kernel.
//
// For every element k of the N x D (row-major) gradient, writes
// dX[k] = -label[k] * dY[k / D] / X[k], so every entry of dX is overwritten.
// Like the other kernels in this file, it uses a grid-stride loop. Correctness
// therefore no longer depends on the launch grid covering all N * D elements.
// When N * D == 0 nothing is done, and D is never used as a divisor.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  const int total = N * D;
  for (int ids = blockIdx.x * blockDim.x + threadIdx.x; ids < total;
       ids += blockDim.x * gridDim.x) {
    const int row_ids = ids / D;
    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
  }
}

// GPU forward kernel of the cross_entropy op.
//
// X is treated as a [batch_size, class_num] matrix (dims()[0] x dims()[1]), and
// Y receives one value per sample.
//   - When the "soft_label" attribute is set, Label holds T values indexed
//     exactly like X.
//   - Otherwise Label holds one int class index per sample.
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    auto x = ctx.Input<Tensor>("X");
    auto y = ctx.Output<Tensor>("Y");
    auto label = ctx.Input<Tensor>("Label");

    auto* x_data = x->data<T>();
    auto* y_data = y->mutable_data<T>(ctx.GetPlace());

    int batch_size = x->dims()[0];
    int class_num = x->dims()[1];
    // A launch with an empty grid is an invalid configuration, and there is
    // nothing to compute for an empty batch.
    if (batch_size == 0) return;

    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();
    int block = 512;

    if (ctx.Attr<bool>("soft_label")) {
      auto* label_data = label->data<T>();
      if (class_num > 512) {
        SoftCrossEntropyKernel2<T, 512><<<batch_size, block, 0, stream>>>(
            y_data, x_data, label_data, class_num);
      } else {
        // Use the largest power of two <= class_num, but never fewer than one
        // full warp. The kernel's last reduction step is a 32-lane warp
        // shuffle, which would read nonexistent lanes if the block had fewer
        // than 32 threads (class_num < 32). Integer doubling also avoids the
        // undefined log2(0) when class_num == 0; surplus threads contribute 0.
        int block_size = 32;
        while (block_size * 2 <= class_num) block_size *= 2;
        SoftCrossEntropyKernel1<T><<<batch_size, block_size,
                                     block_size * sizeof(T), stream>>>(
            y_data, x_data, label_data, class_num);
      }
    } else {
      auto* label_data = label->data<int>();
      int grid = (batch_size + block - 1) / block;
      CrossEntropyKernel<T><<<grid, block, 0, stream>>>(
          y_data, x_data, label_data, batch_size, class_num);
    }
  }
};

// GPU backward kernel of the cross_entropy op.
//
// Computes X@GRAD from X, Y@GRAD and Label. X is treated as an n x d matrix
// (dims()[0] x dims()[1]).
//   - Soft-label path: every element of dX is written by the kernel, so no
//     zero-fill is needed.
//   - Hard-label path: one element per row is written, so dX is zero-filled
//     first.
template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    auto x = ctx.Input<Tensor>("X");
    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto label = ctx.Input<Tensor>("Label");

    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    auto* dy_data = dy->data<T>();
    auto* x_data = x->data<T>();

    int n = x->dims()[0];
    int d = x->dims()[1];
    // dX is empty, and an empty grid would be an invalid launch configuration.
    if (n == 0 || d == 0) return;

    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();
    int block = 512;
    int grid = (n * d + block - 1) / block;

    if (ctx.Attr<bool>("soft_label")) {
      auto* label_data = label->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, n, d);
    } else {
      zero<T><<<grid, block, 0, stream>>>(dx_data, n * d);
      auto* label_data = label->data<int>();
      // This kernel loops over the n rows, so size its grid by n, not n * d.
      int row_grid = (n + block - 1) / block;
      CrossEntropyGradientKernel<T><<<row_grid, block, 0, stream>>>(
          dx_data, dy_data, x_data, label_data, n, d);
    }
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

// GPU kernel registrations (float only) for the forward op and its gradient.
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
                       ops::CrossEntropyGradientOpCUDAKernel<float>);