/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <type_traits>

#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

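// Clamp a value to a finite range. On the device, log(0) evaluates to
// -INFINITY, which would poison every downstream computation, so +/-INFINITY
// is replaced with +/-1e20 (kApproInf) below.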
template <typename T>
HOSTDEVICE T tolerable_value(const T x) {
  PADDLE_ASSERT(std::is_floating_point<T>::value);
  const T kApproInf = 1e20;
  if (x == INFINITY) {
    return kApproInf;
  }
  if (x == -INFINITY) {
    return -kApproInf;
  }
  return x;
}

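// Hard-label cross entropy. Each thread handles one sample i and computes
//   Y[i] = -log(X[i, label[i]]),
// where label[i] is the ground-truth class index in [0, D).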
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  // TODO(qingqing): define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(i, N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    Y[i] = -tolerable_value(log(X[i * D + label[i]]));
  }
}

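// Soft-label cross entropy. Each row of `label` is a probability
// distribution over the D classes, so each thread computes
//   Y[i] = -sum_j label[i, j] * log(X[i, j]).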
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int N, const int D) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    T sum = static_cast<T>(0);
    for (int j = 0; j < D; j++) {
      sum += label[i * D + j] * log(X[i * D + j]);
    }
    Y[i] = -tolerable_value(sum);
  }
}

// TODO(qingqing): make zero setting a common function.
template <typename T>
__global__ void zero(T* X, const int N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    X[i] = static_cast<T>(0);
  }
}

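// Gradient of the hard-label loss. Only the ground-truth column of each row
// is nonzero:
//   dX[i, label[i]] = -dY[i] / X[i, label[i]].
// The caller must clear dX first (see the `zero` kernel above), because all
// other entries are left untouched.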
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int* label, const int N,
                                           const int D) {
  // TODO(qingqing): define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(i, N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    int idx = i * D + label[i];
    dX[idx] = -dY[i] / X[idx];
  }
}

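// Gradient of the soft-label loss. Every entry of the row contributes:
//   dX[i, j] = -label[i, j] * dY[i] / X[i, j].
// Each entry of dX is overwritten, so this kernel does not depend on prior
// zeroing.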
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                               const T* label, const int N,
                                               const int D) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    for (int j = 0; j < D; ++j) {
      int idx = i * D + j;
      dX[idx] = -label[idx] * dY[i] / X[idx];
    }
  }
}

template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");

    auto x = ctx.Input<Tensor>("X");
    auto y = ctx.Output<Tensor>("Y");
    auto label = ctx.Input<Tensor>("Label");

    auto* x_data = x->data<T>();
    auto* y_data = y->mutable_data<T>(ctx.GetPlace());

    int n = x->dims()[0];
    int d = x->dims()[1];
    int block = 512;
    int grid = (n + block - 1) / block;
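    // One thread per sample; the kernels use grid-stride loops, so the grid
    // size above is a launch bound rather than a correctness requirement.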
    // TODO(qingqing): launch the kernel on the stream specified by the
    // ExecutionContext.
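    // Dispatch on the rank of the Label tensor: a rank-2 tensor carries
    // per-class probabilities (soft labels); otherwise Label carries integer
    // class ids.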
    int label_rank = label->dims().size();
    if (label_rank == 2) {
      // soft cross entropy
      auto* label_data = label->data<T>();
      SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
                                                 d);
    } else {
      // normal cross entropy
      auto* label_data = label->data<int>();
      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
    }
  }
};

template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");

    auto x = ctx.Input<Tensor>("X");
    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto label = ctx.Input<Tensor>("Label");

    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    auto* dy_data = dy->data<T>();
    auto* x_data = x->data<T>();

    int n = x->dims()[0];
    int d = x->dims()[1];
    int block = 512;
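    // dX has n * d entries; clear the whole buffer first, then shrink the
    // grid back to one thread per sample for the gradient kernels.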
    int grid = (n * d + block - 1) / block;
    zero<T><<<grid, block>>>(dx_data, n * d);
    grid = (n + block - 1) / block;
    // TODO(qingqing): launch the kernel on the stream specified by the
    // ExecutionContext.
    int label_rank = label->dims().size();
    if (label_rank == 2) {
      // soft cross entropy
      auto* label_data = label->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
          dx_data, dy_data, x_data, label_data, n, d);
    } else {
      // normal cross entropy
      auto* label_data = label->data<int>();
      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
                                                     label_data, n, d);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
                       ops::CrossEntropyGradientOpCUDAKernel<float>);
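// Only float kernels are registered above; a double variant would need its
// own REGISTER_OP_GPU_KERNEL instantiation.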