/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

// Natural logarithm usable from both host and device code.
// The generic version defers to whichever std::log overload matches T and
// converts the result back to T.
template <typename T>
HOSTDEVICE T log(const T& val) {
  const T result = std::log(val);
  return result;
}

// Natural logarithm for platform::float16, usable from host and device code.
//
// A native half-precision log (hlog) is not used here. Instead, the value is
// widened to float, passed to std::log, and rounded back to float16. The
// previous body returned a constant 0 for every input, which made the float16
// cross-entropy loss meaningless.
template <>
HOSTDEVICE platform::float16 log(const platform::float16& val) {
  return static_cast<platform::float16>(std::log(static_cast<float>(val)));
}

namespace {
// Hard-label cross entropy: Y[i] = -log(X[i * D + label[i]]) for i in [0, N).
// X is indexed as N rows of D entries each; label holds one class index per
// row, asserted to lie in [0, D).
//
// Launch: any 1-D grid/block. Rows are visited with a grid-stride loop, so the
// grid does not have to cover all N rows.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
                                   const int N, const int D) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    // Widen before multiplying: i * D in 32-bit int overflows once N * D
    // exceeds INT_MAX.
    const int64_t offset = static_cast<int64_t>(i) * D + label[i];
    Y[i] = -math::TolerableValue<T>()(log(X[offset]));
  }
}

// Soft-label cross entropy:
//   Y[b] = -sum_c label[b * class_num + c] * log(X[b * class_num + c]).
//
// Launch: one block per row (row index = blockIdx.x). The threads of a block
// stride over that row's class_num entries, and their partial sums are
// combined with platform::reduceSum. Only thread 0 uses the reduced value.
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  int tid = threadIdx.x;
  T val(0);

  // 64-bit row bounds: blockIdx.x * class_num is evaluated in 32-bit unsigned
  // arithmetic and would wrap for large batch_size * class_num.
  const int64_t row_begin = static_cast<int64_t>(blockIdx.x) * class_num;
  const int64_t row_end = row_begin + class_num;
  for (int64_t idx = row_begin + tid; idx < row_end; idx += blockDim.x) {
    val += math::TolerableValue<T>()(log(X[idx])) * label[idx];
  }

  val = paddle::platform::reduceSum(val, tid, blockDim.x);
  if (threadIdx.x == 0) {
    Y[blockIdx.x] = -val;
  }
}
}  // namespace

// Local shorthand for the framework tensor type.
typedef framework::Tensor Tensor;

template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
 public:
  // Computes one cross-entropy loss value per row on ctx's CUDA stream.
  //
  // prob:      probabilities viewed as dims()[0] rows (batch_size) of
  //            dims()[1] entries (class_num).
  // labels:    if softLabel, a T tensor laid out like prob that holds
  //            per-class weights; otherwise one int64 class index per row.
  // out:       receives batch_size loss values; allocated on ctx.GetPlace().
  // softLabel: selects between the two label formats above.
  //
  // Kernels are launched asynchronously on ctx.stream(); this function does
  // not synchronize.
  void operator()(const platform::CUDADeviceContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, bool softLabel) {
    const T* prob_data = prob->data<T>();
    T* loss_data = out->mutable_data<T>(ctx.GetPlace());

    int batch_size = prob->dims()[0];
    int class_num = prob->dims()[1];

    // A zero-sized grid is an invalid launch configuration. With no rows
    // there is nothing to compute.
    if (batch_size <= 0) return;

    if (softLabel) {
      const T* label_data = labels->data<T>();
      // Block size: the largest power of two <= class_num, clamped to
      // [1, 512]. This matches the former pow(2, int(log2(class_num)))
      // for class_num >= 1. Integer arithmetic avoids std::log2(0) == -inf,
      // whose conversion to int is undefined behavior.
      int block = 1;
      while (block < 512 && block * 2 <= class_num) block *= 2;

      SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
          loss_data, prob_data, label_data, class_num);
    } else {
      const int64_t* label_data = labels->data<int64_t>();
      int block = 512;
      int grid = (batch_size + block - 1) / block;
      CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
          loss_data, prob_data, label_data, batch_size, class_num);
    }
  }
};

// Explicit instantiations of the CUDA functor for the element types provided
// by this translation unit: float16, double and float.
template class CrossEntropyFunctor<platform::CUDADeviceContext,
                                   platform::float16>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
}  // namespace math
}  // namespace operators
}  // namespace paddle