cross_entropy_op.cu 6.0 KB
Newer Older
L
liaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/framework/op_registry.h"
16
#include "paddle/operators/cross_entropy_op.h"
17
#include "paddle/platform/assert.h"
18
#include "paddle/platform/hostdevice.h"
19 20 21 22 23 24 25 26 27 28 29 30

namespace paddle {
namespace operators {

// Hard-label cross-entropy forward.
// X is read as a row-major N x D matrix and label holds one class index per
// row. For every row i this writes
//   Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]])).
// Rows are walked with a grid-stride loop, so any 1-D launch configuration
// covers all N rows. Labels outside [0, D) trip PADDLE_ASSERT.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  const int step = blockDim.x * gridDim.x;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  while (i < N) {
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    // Probability the model assigned to the true class of row i.
    const T prob = X[i * D + label[i]];
    Y[i] = -TolerableValue<T>()(log(prob));
    i += step;
  }
}

C
caoying03 已提交
35
template <typename T, int BlockSize>
36 37
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int N, const int D) {
38
  int tid = threadIdx.x;
C
caoying03 已提交
39
  __shared__ T d_sum[BlockSize];
40 41 42 43 44
  int next_idx = blockIdx.x * D + tid;

  d_sum[tid] = 0;
  int cur_idx = tid;
  while (cur_idx < D) {
C
caoying03 已提交
45 46 47
    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
    next_idx += BlockSize;
    cur_idx += BlockSize;
48 49 50
  }
  __syncthreads();

C
caoying03 已提交
51
  for (int stride = BlockSize >> 1; stride > 0; stride >>= 1) {
52 53 54 55
    __syncthreads();
    if (tid < stride) {
      next_idx = tid + stride;
      d_sum[tid] += d_sum[next_idx];
56
    }
57 58 59 60 61
  }
  __syncthreads();

  if (tid == 0) {
    Y[blockIdx.x] = -d_sum[0];
62 63 64
  }
}

// TODO(qingqing): make zero setting a common function.
// Writes zero to X[0 .. N). A grid-stride loop makes the result independent
// of the 1-D launch configuration.
template <typename T>
__global__ void zero(T* X, const int N) {
  const int step = blockDim.x * gridDim.x;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  while (i < N) {
    X[i] = static_cast<T>(0);
    i += step;
  }
}

// Hard-label cross-entropy backward.
// For each of the N rows, writes only the true-class entry:
//   dX[row * D + label[row]] = -dY[row] / X[row * D + label[row]].
// Other entries of dX are not touched, so the caller must clear dX first.
// A grid-stride loop covers all rows for any 1-D launch.
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int* label, const int N,
                                           const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  const int step = blockDim.x * gridDim.x;
  for (int row = blockIdx.x * blockDim.x + threadIdx.x; row < N; row += step) {
    const int target = row * D + label[row];
    dX[target] = -dY[row] / X[target];
  }
}

// Soft-label cross-entropy backward.
// X, label and dX are read as row-major N x D matrices, and dY holds one value
// per row:
//   dX[r * D + c] = -label[r * D + c] * dY[r] / X[r * D + c].
// Expects a 2-D launch: the x dimension (blockIdx.x/threadIdx.x) covers the N
// rows and the y dimension covers the D columns.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  int row_ids = blockIdx.x * blockDim.x + threadIdx.x;
  int col_ids = blockIdx.y * blockDim.y + threadIdx.y;
  // Bounds-check row and column separately. A flat `ids < N * D` test lets a
  // thread with col_ids >= D (D not a multiple of blockDim.y) wrap into the
  // next row. It then races with that row's own thread and writes a value
  // scaled by the wrong dY.
  if (row_ids < N && col_ids < D) {
    int ids = row_ids * D + col_ids;
    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
  }
}

// GPU forward kernel for the cross_entropy op.
// X is read as an n x d matrix (dims()[0] x dims()[1]); Y receives one loss
// value per row. When the "soft_label" attribute equals 1, Label is read as
// an n x d matrix of T; otherwise it is read as one int class index per row.
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    auto x = ctx.Input<Tensor>("X");
    auto y = ctx.Output<Tensor>("Y");
    auto label = ctx.Input<Tensor>("Label");

    auto* x_data = x->data<T>();
    auto* y_data = y->mutable_data<T>(ctx.GetPlace());

    int n = x->dims()[0];
    int d = x->dims()[1];
    // A launch with zero blocks is an invalid configuration; nothing to do.
    if (n == 0) {
      return;
    }

    // TODO(qingqing) launch kernel on specified stream
    // base on ExecutionContext.
    if (ctx.Attr<int>("soft_label") == 1) {
      // SoftCrossEntropyKernel reduces row blockIdx.x and writes
      // Y[blockIdx.x], so it needs exactly one block per row (n blocks).
      // Using d blocks here overruns Y/X when d > n and skips rows when d < n.
      constexpr int kSoftBlock = 512;
      auto* label_data = label->data<T>();
      SoftCrossEntropyKernel<T, kSoftBlock><<<n, kSoftBlock>>>(
          y_data, x_data, label_data, n, d);
    } else {
      int block = 512;
      int grid = (n + block - 1) / block;
      auto* label_data = label->data<int>();
      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
    }
  }
};

// GPU backward kernel for the cross_entropy op.
// X is read as an n x d matrix (dims()[0] x dims()[1]) and dY holds one value
// per row. dX is cleared first. It is then filled by the soft-label gradient
// kernel when the "soft_label" attribute equals 1, and by the hard-label
// gradient kernel otherwise.
template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");

    auto x = ctx.Input<Tensor>("X");
    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto label = ctx.Input<Tensor>("Label");

    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    const auto* dy_data = dy->data<T>();
    const auto* x_data = x->data<T>();

    const int rows = x->dims()[0];
    const int cols = x->dims()[1];
    const int threads_1d = 512;
    const int total = rows * cols;

    // Clear dX: the hard-label gradient kernel writes only one entry per row.
    zero<T><<<(total + threads_1d - 1) / threads_1d, threads_1d>>>(dx_data,
                                                                  total);

    // TODO(qingqing): launch kernel on specified stream
    // base on ExecutionContext.
    if (ctx.Attr<int>("soft_label") != 1) {
      const int blocks_1d = (rows + threads_1d - 1) / threads_1d;
      CrossEntropyGradientKernel<T><<<blocks_1d, threads_1d>>>(
          dx_data, dy_data, x_data, label->data<int>(), rows, cols);
      return;
    }

    // Soft labels: 32 x 32 thread tiles; x dimension over rows, y over columns.
    const int tile = 32;
    const dim3 tile_threads(tile, tile);
    const dim3 tile_blocks((rows + tile - 1) / tile, (cols + tile - 1) / tile);
    SoftCrossEntropyGradientKernel<T><<<tile_blocks, tile_threads>>>(
        dx_data, dy_data, x_data, label->data<T>(), rows, cols);
  }
};

}  // namespace operators
}  // namespace paddle
Q
Qiao Longfei 已提交
176

D
dongzhihong 已提交
177
// Register the GPU kernels for the forward op (cross_entropy) and its
// gradient op (cross_entropy_grad). Only the float instantiation is
// registered in this file.
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
                       ops::CrossEntropyGradientOpCUDAKernel<float>);