cross_entropy.cu 5.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

15
#include "paddle/fluid/operators/math/cross_entropy.h"
16
#include "paddle/fluid/framework/convert_utils.h"
S
sneaxiy 已提交
17
#include "paddle/fluid/operators/math.h"
18 19
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
20
#include "paddle/phi/backends/gpu/gpu_context.h"
21 22 23 24 25

namespace paddle {
namespace operators {
namespace math {

26
/**
 * Hard-label cross entropy over the N rows of X, where row i occupies
 * X[i * D .. i * D + D) (row-major N x D, as the indexing below implies):
 *
 *   Y[i] = 0                                          if label[i] == ignore_index
 *   Y[i] = -TolerableValue(real_log(X[i*D + label[i]]))  otherwise
 *
 * math::TolerableValue and real_log are defined elsewhere.
 *
 * A label outside [0, D) that is not ignore_index trips the device-side
 * PADDLE_ENFORCE. Ignored rows never read X, so ignore_index may itself lie
 * outside [0, D).
 *
 * Launch: 1-D; rows are visited via CUDA_KERNEL_LOOP(i, N) (defined
 * elsewhere; presumably a grid-stride loop over [0, N) — confirm in the
 * gpu helper headers).
 */
template <typename T, typename LabelT>
__global__ void CrossEntropyKernel(T* Y,
                                   const T* X,
                                   const LabelT* label,
                                   const int N,
                                   const int D,
                                   const int ignore_index) {
  CUDA_KERNEL_LOOP(i, N) {
    auto lbl = static_cast<int64_t>(label[i]);
    // Device printf is varargs: each argument must match its conversion
    // specifier exactly. D and ignore_index are int, which %ld would misread,
    // so every value is widened to long long and printed with %lld.
    PADDLE_ENFORCE((lbl >= 0 && lbl < D) || lbl == ignore_index,
                   "The value of label[%lld] expected >= 0 and < %lld, or == "
                   "%lld, but got %lld. Please check input value.",
                   static_cast<long long>(i),
                   static_cast<long long>(D),
                   static_cast<long long>(ignore_index),
                   static_cast<long long>(lbl));
    // Form the flat offset in 64 bits: i * D in int arithmetic overflows
    // once the probability matrix has more than 2^31 elements.
    Y[i] = ignore_index == lbl
               ? static_cast<T>(0)
               : -math::TolerableValue<T>()(
                     real_log(X[static_cast<int64_t>(i) * D + lbl]));
  }
}

/**
 * Soft-label cross entropy, one thread block per row:
 *
 *   Y[r] = -sum_j TolerableValue(real_log(X[r*class_num + j]))
 *                 * label[r*class_num + j]
 *
 * r = blockIdx.x. math::TolerableValue and real_log are defined elsewhere.
 *
 * Launch: gridDim.x must equal the number of rows. Each thread accumulates
 * the entries of its row at stride blockDim.x, starting at its threadIdx.x.
 * The partial sums are then combined with platform::reduceSum (defined
 * elsewhere), and thread 0 writes Y[r].
 */
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y,
                                       const T* X,
                                       const T* label,
                                       const int class_num) {
  const int tid = threadIdx.x;
  // 64-bit row bounds: blockIdx.x * class_num would otherwise be evaluated in
  // 32-bit unsigned arithmetic and wrap for large batch_size * class_num.
  const int64_t row_begin = static_cast<int64_t>(blockIdx.x) * class_num;
  const int64_t row_end = row_begin + class_num;

  T val(0);
  for (int64_t idx = row_begin + tid; idx < row_end; idx += blockDim.x) {
    val += math::TolerableValue<T>()(real_log(X[idx])) * label[idx];
  }

  // Called unconditionally by every thread of the block. Keep it out of any
  // divergent branch; it presumably synchronizes the block internally
  // (verify in gpu_device_function.h).
  val = paddle::platform::reduceSum(val, tid, blockDim.x);
  if (tid == 0) {
    Y[blockIdx.x] = -val;
  }
}

68 69 70
/**
 * Launcher for CrossEntropyKernel with hard (integer) labels.
 *
 * The label dtype is only known at runtime, so the label buffer is stored as
 * const void*. apply<U>() reinterprets it as const U* and launches
 * CrossEntropyKernel<T, U>. In this file the object is handed to
 * framework::VisitDataType, which presumably calls apply<U>() with the C++
 * type matching the label tensor's dtype.
 */
template <typename T>
struct HardLabelCrossEntropyCUDAFunctorImpl {
 public:
  HardLabelCrossEntropyCUDAFunctorImpl(T* loss_data,
                                       const T* prob_data,
                                       const void* label_data,
                                       const int batch_size,
                                       const int class_num,
                                       const int ignore_index,
                                       const int block_size,
                                       gpuStream_t stream)
      : loss_data_(loss_data),
        prob_data_(prob_data),
        label_data_(label_data),
        batch_size_(batch_size),
        class_num_(class_num),
        ignore_index_(ignore_index),
        block_size_(block_size),
        stream_(stream) {}

  // Launches CrossEntropyKernel<T, U> on stream_ with block_size_ threads per
  // block and ceil(batch_size_ / block_size_) blocks.
  template <typename U>
  void apply() const {
    // With no rows there is nothing to compute, and a zero-block grid is an
    // invalid launch configuration.
    if (batch_size_ <= 0) {
      return;
    }
    const int grid_size = (batch_size_ + block_size_ - 1) / block_size_;
    CrossEntropyKernel<T, U><<<grid_size, block_size_, 0, stream_>>>(
        loss_data_,
        prob_data_,
        static_cast<const U*>(label_data_),
        batch_size_,
        class_num_,
        ignore_index_);
  }

 private:
  T* loss_data_;           // Output: one loss value per row.
  const T* prob_data_;     // Input probabilities, batch_size_ x class_num_.
  const void* label_data_; // Labels; element type chosen by apply<U>().
  const int batch_size_;
  const int class_num_;
  const int ignore_index_;
  const int block_size_;   // Threads per block for the launch.
  gpuStream_t stream_;
};

111 112
/**
 * Computes one cross-entropy loss per row of `prob` and writes it to `out`,
 * launching on ctx.stream().
 *
 * @param ctx          Device context supplying the place and stream.
 * @param out          Output tensor; allocated here via mutable_data<T>.
 * @param prob         Probabilities; dims()[0] is the batch size and
 *                     dims()[1] the class count.
 * @param labels       With softLabel: a T tensor laid out like prob.
 *                     Without: hard labels of any dtype, dispatched through
 *                     framework::VisitDataType.
 * @param softLabel    Selects SoftCrossEntropyKernel (true) or the
 *                     hard-label CrossEntropyKernel (false).
 * @param ignore_index Hard-label value whose rows receive a loss of 0.
 * @param axis_dim     Not used by this implementation.
 */
template <typename DeviceContext, typename T>
void CrossEntropyFunctor<DeviceContext, T>::operator()(
    const DeviceContext& ctx,
    framework::Tensor* out,
    const framework::Tensor* prob,
    const framework::Tensor* labels,
    const bool softLabel,
    const int ignore_index,
    const int axis_dim) {
  const T* prob_data = prob->data<T>();
  T* loss_data = out->mutable_data<T>(ctx.GetPlace());

  int batch_size = prob->dims()[0];
  int class_num = prob->dims()[1];
#ifdef __HIPCC__
  constexpr int kMaxBlockDim = 256;
#else
  constexpr int kMaxBlockDim = 512;
#endif

  // Empty batch: nothing to compute. Launching a zero-sized grid would be an
  // invalid configuration. `out` has already been allocated above.
  if (batch_size <= 0) {
    return;
  }

  if (softLabel) {
    const T* label_data = labels->data<T>();
    // One block per row; the block size is the largest power of two
    // <= min(class_num, kMaxBlockDim) (kMaxBlockDim is itself a power of
    // two). Integer arithmetic avoids a float pow/log2 round-trip, and
    // log2(0) = -inf would make the int cast undefined for class_num == 0;
    // here that case yields 1.
    const int block_limit =
        class_num < kMaxBlockDim ? class_num : kMaxBlockDim;
    int block = 1;
    while (block * 2 <= block_limit) {
      block *= 2;
    }

    SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
        loss_data, prob_data, label_data, class_num);
  } else {
    // The label dtype is resolved at runtime by VisitDataType, which drives
    // the functor's templated apply<U>().
    HardLabelCrossEntropyCUDAFunctorImpl<T> functor(loss_data,
                                                    prob_data,
                                                    labels->data(),
                                                    batch_size,
                                                    class_num,
                                                    ignore_index,
                                                    kMaxBlockDim,
                                                    ctx.stream());
    framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()),
                             functor);
  }
}
152

153 154 155 156
// Explicitly instantiate the GPU CrossEntropyFunctor for half, single and
// double precision.
template class CrossEntropyFunctor<phi::GPUContext, platform::float16>;
template class CrossEntropyFunctor<phi::GPUContext, float>;
template class CrossEntropyFunctor<phi::GPUContext, double>;

157 158 159
}  // namespace math
}  // namespace operators
}  // namespace paddle