cross_entropy.cu 4.0 KB
Newer Older
1 2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/math/cross_entropy.h"
16 17 18 19 20 21 22

namespace paddle {
namespace operators {
namespace math {

namespace {
// Hard-label cross entropy. For every row i in [0, N) this writes
//   Y[i] = -TolerableValue(log(X[i * D + label[i]])).
//
// Y     : output; Y[i] is written for each i in [0, N).
// X     : input, indexed as a row-major N x D array.
// label : one class index per row; asserted to lie in [0, D).
// N, D  : number of rows and number of entries per row.
//
// The loop is grid-stride, so any 1-D launch configuration visits all rows.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
                                   const int N, const int D) {
  // Index arithmetic is done in 64 bits: i * D no longer fits in a 32-bit int
  // once the input holds more than 2^31 elements, and would silently wrap.
  const int64_t first =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = first; i < N; i += step) {
    // Read the label once; it is used by both the check and the lookup.
    const int64_t target = label[i];
    PADDLE_ASSERT(target >= 0 && target < D);
    Y[i] = -math::TolerableValue<T>()(log(X[i * D + target]));
  }
}

// Sums `val` across the lanes of the calling warp using a shuffle-down tree
// with offsets 16, 8, 4, 2, 1. When all 32 lanes take part, lane 0 returns the
// total; the values returned by the other lanes are not full sums.
//
// All lanes of the warp are expected to call this together.
// NOTE(review): the full mask 0xffffffff assumes the warp is fully populated;
// confirm callers never launch blocks narrower than one warp.
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000
    // The mask-less __shfl_down is deprecated since CUDA 9 and is not
    // available when compiling for Volta and newer architectures; the *_sync
    // form takes an explicit participant mask.
    val += __shfl_down_sync(0xffffffffu, val, offset);
#else
    // Toolkits older than CUDA 9 only provide the legacy intrinsic.
    val += __shfl_down(val, offset);
#endif
  }
  return val;
}

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
// CUDA does not support dynamic arrays in templates; see
// https://stackoverflow.com/questions/20497209
// Primary template of the typed accessor for dynamically sized shared memory.
// Only the explicit specializations are meant to be used: calling
// GetPointer() for any other type is rejected at compile time instead of
// handing back a null pointer.
template <typename T>
struct SharedMemory {
  __device__ T* GetPointer() {
    // The condition depends on T, so it is evaluated only if this member is
    // actually instantiated for an unsupported type.
    static_assert(sizeof(T) == 0,
                  "SharedMemory<T> is only specialized for float and double");
    return NULL;
  }
};

// float specialization: gives typed access to the block's dynamically sized
// shared memory.
template <>
struct SharedMemory<float> {
  // Returns the extern __shared__ float array; its size is whatever the
  // kernel launch passed as the dynamic shared-memory byte count.
  __device__ float* GetPointer() {
    extern __shared__ float s_float[];
    return s_float;
  }
};

// double specialization: gives typed access to the block's dynamically sized
// shared memory.
template <>
struct SharedMemory<double> {
  // Returns the extern __shared__ double array; its size is whatever the
  // kernel launch passed as the dynamic shared-memory byte count.
  __device__ double* GetPointer() {
    extern __shared__ double s_double[];
    return s_double;
  }
};

66 67 68 69
// Soft-label cross entropy. Each block handles one row (row = blockIdx.x) and
// writes
//   Y[row] = -sum_j TolerableValue(log(X[row * class_num + j]))
//                   * label[row * class_num + j],   0 <= j < class_num.
//
// Y         : output, one value per block.
// X, label  : inputs, both indexed as row-major arrays with class_num
//             entries per row.
// class_num : number of entries per row.
//
// Launch requirements:
//   * 1-D blocks, one block per row;
//   * blockDim.x * sizeof(T) bytes of dynamic shared memory;
//   * NOTE(review): the halving reduction only folds every partial sum when
//     blockDim.x is a power of two, and the final warp shuffle reads lanes up
//     to 31, so blockDim.x should presumably be a power of two >= 32.
//     Confirm against the launcher.
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  const int tid = threadIdx.x;
  SharedMemory<T> d_sum_shared;
  T* d_sum = d_sum_shared.GetPointer();

  // 64-bit row offset: blockIdx.x * class_num wraps in 32-bit arithmetic once
  // the input holds more than 2^31 elements.
  const int64_t row_begin = static_cast<int64_t>(blockIdx.x) * class_num;

  // Each thread accumulates columns tid, tid + blockDim.x, ... in a register
  // and publishes its partial sum to shared memory once.
  T partial = 0;
  for (int64_t col = tid; col < class_num; col += blockDim.x) {
    const int64_t idx = row_begin + col;
    partial += math::TolerableValue<T>()(std::log(X[idx])) * label[idx];
  }
  d_sum[tid] = partial;
  __syncthreads();

  // Fold the partial sums in shared memory until at most 32 remain.
  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
    __syncthreads();
  }

  // Finish inside a single warp; only thread 0's result is stored.
  T val = d_sum[tid];
  val = sum_single_warp<T>(val);
  if (tid == 0) Y[blockIdx.x] = -val;
}
}  // namespace

// Shorthand for framework::Tensor inside this namespace.
using Tensor = framework::Tensor;

// CUDA specialization of CrossEntropyFunctor. Computes one loss value per row
// of `prob`, stores it in `out`, and launches its kernel on ctx.stream().
template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
 public:
  // ctx       : device context; supplies the output place and the stream.
  // out       : receives the per-row loss; its buffer is obtained here through
  //             mutable_data<T>().
  // prob      : input, read as a [batch_size, class_num] array of T.
  // labels    : T data indexed like `prob` when softLabel is true, otherwise
  //             int64_t data holding one class index per row.
  // softLabel : selects SoftCrossEntropyKernel (true) or CrossEntropyKernel.
  void operator()(const platform::CUDADeviceContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, bool softLabel) {
    const T* prob_data = prob->data<T>();
    T* loss_data = out->mutable_data<T>(ctx.GetPlace());

    int batch_size = prob->dims()[0];
    int class_num = prob->dims()[1];

    // Both launches below derive the grid size from batch_size, and an empty
    // grid is an invalid launch configuration; there is nothing to compute.
    if (batch_size <= 0) return;

    if (softLabel) {
      const T* label_data = labels->data<T>();
      // Largest power of two not exceeding class_num, clamped to [32, 512].
      // The lower bound keeps at least one full warp in the block: the
      // kernel's final warp shuffle reads lanes up to 31, which would not
      // exist in a narrower block. Threads beyond class_num contribute zero.
      // Integer arithmetic also avoids log2() of a non-positive class_num.
      int block = 32;
      while (block < 512 && block * 2 <= class_num) block *= 2;

      SoftCrossEntropyKernel<T><<<batch_size, block, block * sizeof(T),
                                  ctx.stream()>>>(loss_data, prob_data,
                                                  label_data, class_num);
    } else {
      const int64_t* label_data = labels->data<int64_t>();
      int block = 512;
      int grid = (batch_size + block - 1) / block;  // ceil-div over rows
      CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
          loss_data, prob_data, label_data, batch_size, class_num);
    }
  }
};

Q
QI JUN 已提交
127 128
// Explicit instantiations of the CUDA functor for float and double.
template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
129 130 131
}  // namespace math
}  // namespace operators
}  // namespace paddle