onehot_cross_entropy_op.cu 4.4 KB
Newer Older
L
liaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15 16 17 18 19 20 21 22
#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"

namespace paddle {
namespace operators {

// Shorthand for framework::Tensor within this namespace.
using Tensor = framework::Tensor;

23
// Computes log(x), replacing an infinite result with a large finite value.
//
// If log(x) is +INFINITY the function returns kApproInf (1e20); if it is
// -INFINITY (e.g. x == 0) it returns -kApproInf. Any other result is returned
// unchanged. A NaN result compares unequal to both infinities, so it is
// passed through as-is.
//
// T must be a floating-point type. This is a property of the instantiation,
// not of the runtime value, so it is enforced at compile time.
template <typename T>
__host__ __device__ T clipping_log(const T x) {
  static_assert(std::is_floating_point<T>::value,
                "clipping_log requires a floating-point type.");
  const T kApproInf = 1e20;
  const T v = log(x);
  if (v == INFINITY) {
    return kApproInf;
  }
  if (v == -INFINITY) {
    return -kApproInf;
  }
  return v;
}
36

37 38 39 40 41 42 43 44
// Computes the per-row cross-entropy loss for integer class labels.
//
// For every row i in [0, N) this writes
//   Y[i] = -clipping_log(X[i * D + label[i]])
// i.e. X is indexed as a row-major N x D matrix and label[i] selects the
// column of row i. label[i] must lie in [0, D); this is asserted per row.
//
// Each thread starts at its global 1-D index and advances by the total number
// of launched threads, so every row is visited regardless of the grid size.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(i, N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    const int target = label[i];
    PADDLE_ASSERT(target >= 0 && target < D);
    // Widen before multiplying: i * D overflows a 32-bit int as soon as
    // N * D exceeds INT_MAX, which would index X out of bounds.
    const long long offset = static_cast<long long>(i) * D + target;
    Y[i] = -clipping_log(X[offset]);
  }
}

49
// TODO(qingqing): make zero setting a common function.
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
// Sets the first N elements of X to zero.
//
// A thread begins at its global 1-D index and keeps stepping by the total
// number of threads in the launch until it runs past N.
template <typename T>
__global__ void zero(T* X, const int N) {
  const unsigned int step = blockDim.x * gridDim.x;
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < N) {
    X[pos] = 0.0;
    pos += step;
  }
}

// Computes the gradient of the cross-entropy loss with respect to X.
//
// For every row i in [0, N) this writes a single element,
//   dX[i * D + label[i]] = -dY[i] / X[i * D + label[i]]
// and leaves every other element of dX untouched.
//
// NOTE: label[i] is not validated here and X[idx] is not guarded against
// zero; the caller must supply labels in [0, D).
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                           const int* label, const int N,
                                           const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(i, N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    // Widen before multiplying: i * D overflows a 32-bit int as soon as
    // N * D exceeds INT_MAX, which would read and write out of bounds.
    const long long idx = static_cast<long long>(i) * D + label[i];
    dX[idx] = -dY[i] / X[idx];
  }
}

// GPU implementation of the forward onehot_cross_entropy computation.
//
// Compute() reads input "X" (element type T) and input "label" (int data),
// allocates output "Y" on the context's place, and launches
// CrossEntropyKernel with N = X->dims()[0] rows and D = X->dims()[1] columns.
template <typename T>
class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");

    auto X = ctx.Input<Tensor>("X");
    const T* Xdata = X->data<T>();
    const int* label_data = ctx.Input<Tensor>("label")->data<int>();
    auto Y = ctx.Output<Tensor>("Y");
    Y->mutable_data<T>(ctx.GetPlace());
    T* Ydata = Y->data<T>();

    int N = X->dims()[0];
    int D = X->dims()[1];
    // An empty batch would yield a zero-block launch, which CUDA rejects as
    // an invalid configuration; there is nothing to compute in that case.
    if (N <= 0) {
      return;
    }

    int block = 512;
    // Ceil-divide so that every row gets a thread.
    int grid = (N + block - 1) / block;
    // TODO(qingqing) launch kernel on specified stream
    // base on ExecutionContext.
    CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
  }
};

// GPU implementation of the onehot_cross_entropy gradient computation.
//
// Compute() reads "X", "label" and the gradient of "Y", allocates the
// gradient of "X" on the context's place, zero-fills all N * D of its
// elements, and then launches CrossEntropyGradientKernel, which writes one
// element per row (the one at the label column).
template <typename T>
class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");

    auto X = ctx.Input<Tensor>("X");
    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto label = ctx.Input<Tensor>("label");

    auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
    auto* dYdata = dY->template data<T>();
    auto* Xdata = X->template data<T>();
    auto* label_data = label->data<int>();

    int N = X->dims()[0];
    int D = X->dims()[1];
    // With no elements there is nothing to fill or differentiate, and a
    // zero-block launch would be rejected as an invalid configuration.
    if (N <= 0 || D <= 0) {
      return;
    }

    // Compute the element count in 64 bits: N * D in int overflows for large
    // inputs. zero() takes an int count, so reject sizes it cannot express
    // instead of launching with a wrapped-around value.
    const long long numel = static_cast<long long>(N) * D;
    PADDLE_ENFORCE(numel <= 2147483647LL,
                   "The number of elements of X exceeds the range of int.");

    int block = 512;
    int grid = static_cast<int>((numel + block - 1) / block);
    zero<T><<<grid, block>>>(dXdata, static_cast<int>(numel));

    grid = (N + block - 1) / block;
    // TODO(qingqing): launch kernel on specified stream
    // base on ExecutionContext.
    CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
                                                   label_data, N, D);
  }
};

}  // namespace operators
}  // namespace paddle
Q
Qiao Longfei 已提交
128

D
dongzhihong 已提交
129
// Short alias for the paddle::operators namespace.
namespace ops = paddle::operators;
130 131 132 133
// Pass the float instantiations of the forward and gradient kernel classes to
// REGISTER_OP_GPU_KERNEL for the ops onehot_cross_entropy and
// onehot_cross_entropy_grad. Only float is instantiated here.
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
                       ops::OnehotCrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
                       ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);