/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math.h"
Y
Yi Wang 已提交
16
#include "paddle/fluid/operators/math/cross_entropy.h"
17
#include "paddle/fluid/platform/cuda_device_function.h"
D
dzhwinter 已提交
18
#include "paddle/fluid/platform/cuda_primitives.h"
19 20 21 22 23 24

namespace paddle {
namespace operators {
namespace math {

/**
 * Hard-label cross entropy: writes one loss value per row of X.
 *
 * For every row i in [0, N):
 *   Y[i] = 0                                        if label[i] == ignore_index
 *   Y[i] = -TolerableValue(real_log(X[i*D + label[i]]))   otherwise
 *
 * @param Y            output, written at indices [0, N)
 * @param X            input, read at offset i * D + label[i]
 * @param label        one class index per row; asserted to lie in [0, D) or
 *                     to equal ignore_index
 * @param N            number of rows
 * @param D            row stride of X
 * @param ignore_index label value whose row receives a zero loss
 *
 * Launch layout: 1-D grid of 1-D blocks. Each thread starts at its global
 * index and advances by blockDim.x * gridDim.x, so any grid size covers all
 * N rows. No shared memory is used.
 */
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
                                   const int N, const int D,
                                   const int ignore_index) {
  // 64-bit index and stride: with an int counter, `i += stride` can wrap to a
  // negative value (and pass `i < N` again) when N is close to INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < N; i += stride) {
    const int64_t lbl = label[i];
    PADDLE_ASSERT((lbl >= 0 && lbl < D) || lbl == ignore_index);
    // The offset i * D is formed in 64 bits; in 32-bit int it overflows as
    // soon as N * D exceeds INT_MAX.
    Y[i] = ignore_index == lbl
               ? static_cast<T>(0)
               : -math::TolerableValue<T>()(real_log(X[i * D + lbl]));
  }
}

/**
 * Soft-label cross entropy: writes one loss value per row of X.
 *
 * Block b handles row b:
 *   Y[b] = -sum_j TolerableValue(real_log(X[b*class_num + j])) * label[b*class_num + j]
 *
 * @param Y          output, one element per block (indexed by blockIdx.x)
 * @param X          input, read as rows of class_num elements
 * @param label      per-element weights, indexed exactly like X
 * @param class_num  number of elements per row
 *
 * Launch layout: gridDim.x = number of rows; threads of a block stride over
 * the row by blockDim.x and their partial sums are combined with
 * paddle::platform::reduceSum. Thread 0 stores the result.
 * NOTE(review): reduceSum is defined elsewhere; the host launcher in this
 * file always picks a power-of-two block size, presumably because reduceSum
 * needs it -- confirm before launching with other block sizes.
 */
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  const int tid = threadIdx.x;
  T val(0);

  // Row offsets are computed in 64 bits: blockIdx.x * class_num wraps in
  // 32-bit arithmetic once rows * class_num reaches 2^31.
  const int64_t row_begin = static_cast<int64_t>(blockIdx.x) * class_num;
  const int64_t row_end = row_begin + class_num;
  for (int64_t idx = row_begin + tid; idx < row_end; idx += blockDim.x) {
    val += math::TolerableValue<T>()(real_log(X[idx])) * label[idx];
  }

  // Every thread of the block takes part in the reduction, including those
  // whose loop above ran zero iterations (val stays 0 for them).
  val = paddle::platform::reduceSum(val, tid, blockDim.x);
  if (tid == 0) {
    Y[blockIdx.x] = -val;
  }
}

/**
 * CUDA specialization of CrossEntropyFunctor.
 *
 * Reads prob->dims()[0] as the number of rows and prob->dims()[1] as the
 * number of classes, allocates `out` on the context's place and launches one
 * of the two kernels above on ctx.stream().
 */
template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
 public:
  /**
   * @param ctx          device context; supplies the place and the stream
   * @param out          receives the per-row loss
   * @param prob         input values whose log is taken by the kernels
   * @param labels       read as T when softLabel is true, as int64_t otherwise
   * @param softLabel    selects SoftCrossEntropyKernel (true) or
   *                     CrossEntropyKernel (false)
   * @param ignore_index forwarded to CrossEntropyKernel; not used on the
   *                     soft-label path
   * @param axis_dim     not used by this implementation
   */
  void operator()(const platform::CUDADeviceContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, const bool softLabel,
                  const int ignore_index, const int axis_dim) {
    const int kMaxBlockDim = 512;

    const T* prob_data = prob->data<T>();
    T* loss_data = out->mutable_data<T>(ctx.GetPlace());

    const int batch_size = prob->dims()[0];
    const int class_num = prob->dims()[1];

    // Both paths derive the grid size from batch_size; a zero-sized grid is
    // an invalid launch configuration, and there is nothing to compute.
    if (batch_size <= 0) {
      return;
    }

    if (softLabel) {
      const T* label_data = labels->data<T>();
      // Largest power of two that is <= class_num, capped at kMaxBlockDim.
      // Done with integers so that class_num <= 0 yields 1 instead of going
      // through log2() of a non-positive value and an undefined float->int
      // conversion.
      int block = 1;
      while (block < kMaxBlockDim && block * 2 <= class_num) {
        block *= 2;
      }

      // One block per row.
      SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
          loss_data, prob_data, label_data, class_num);
    } else {
      const int64_t* label_data = labels->data<int64_t>();
      const int block = kMaxBlockDim;
      // Ceil-divide so that every row gets a thread.
      const int grid = (batch_size + block - 1) / block;
      CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
          loss_data, prob_data, label_data, batch_size, class_num,
          ignore_index);
    }
  }
};

Q
QI JUN 已提交
87 88
// Explicit instantiations of the CUDA functor for the element types
// float, double and platform::float16.
template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
template class CrossEntropyFunctor<
    platform::CUDADeviceContext, platform::float16>;
91 92 93
}  // namespace math
}  // namespace operators
}  // namespace paddle