#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/kernels/funcs/math_function.h"

#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"

#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"

namespace pten {

using IndexType = int64_t;
using paddle::platform::PADDLE_CUDA_NUM_THREADS;

// Grid size needed to cover N elements with PADDLE_CUDA_NUM_THREADS threads
// per block (ceiling division).
inline int GET_BLOCKS(const int N) {
  return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}

// Map input_value to a bin index in [0, nbins) for nbins equal-width bins
// over [min_value, max_value]; values equal to max_value fall in the last bin.
template <typename T, typename IndexType>
__device__ static IndexType GetBin(T input_value, T min_value, T max_value,
                                   int64_t nbins) {
  IndexType bin = static_cast<IndexType>((input_value - min_value) * nbins /
                                         (max_value - min_value));
  IndexType output_index = bin < nbins - 1 ? bin : nbins - 1;
  return output_index;
}

// Shared-memory histogram kernel: each block builds a private histogram of
// its portion of the input in shared memory, then merges it into the global
// `output` with atomic adds. Elements outside [min_value, max_value] are
// skipped.
template <typename T, typename IndexType>
__global__ void KernelHistogram(const T* input, const int total_elements,
                                const int64_t nbins, const T min_value,
                                const T max_value, int64_t* output) {
  extern __shared__ int64_t buf_hist[];
  // Zero the block-local histogram.
  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
    buf_hist[i] = 0;
  }
  __syncthreads();

  // Grid-stride loop over all input elements.
  CUDA_KERNEL_LOOP(input_index, total_elements) {
    const auto input_value = input[input_index];
    if (input_value >= min_value && input_value <= max_value) {
      const IndexType output_index =
          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
      paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1);
    }
  }
  __syncthreads();

  // Merge the block-local histogram into the global result.
  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
    paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
  }
}

template <typename T, typename Context>
void HistogramKernel(const Context& dev_ctx,
                     const DenseTensor& input,
                     int64_t bins,
                     int min,
                     int max,
                     DenseTensor* output) {
  auto& nbins = bins;
  auto& minval = min;
  auto& maxval = max;

  const T* input_data = input.data<T>();
  const int input_numel = input.numel();

  int64_t* out_data = output->mutable_data<int64_t>(dev_ctx.GetPlace());
  pten::funcs::SetConstant<Context, int64_t>()(
      dev_ctx, output, static_cast<int64_t>(0));

  if (input_data == nullptr) return;

  T output_min = static_cast<T>(minval);
  T output_max = static_cast<T>(maxval);

  // min == max (e.g. both left at the default 0) means the histogram range is
  // taken from the minimum and maximum of the input data.
  if (output_min == output_max) {
    auto input_x = pten::EigenVector<T>::Flatten(input);

    DenseTensor input_min_t, input_max_t;
    auto* input_min_data =
        input_min_t.mutable_data<T>({1}, dev_ctx.GetPlace());
    auto* input_max_data =
        input_max_t.mutable_data<T>({1}, dev_ctx.GetPlace());
    auto input_min_scala = pten::EigenScalar<T>::From(input_min_t);
    auto input_max_scala = pten::EigenScalar<T>::From(input_max_t);

    auto* place = dev_ctx.eigen_device();
    input_min_scala.device(*place) = input_x.minimum();
    input_max_scala.device(*place) = input_x.maximum();

    DenseTensor input_min_cpu, input_max_cpu;
    paddle::framework::TensorCopySync(
        input_min_t, paddle::platform::CPUPlace(), &input_min_cpu);
    paddle::framework::TensorCopySync(
        input_max_t, paddle::platform::CPUPlace(), &input_max_cpu);

    output_min = input_min_cpu.data<T>()[0];
    output_max = input_max_cpu.data<T>()[0];
  }
  // Still-degenerate range (all elements equal): widen it by one on each side
  // so the bin width is non-zero.
  if (output_min == output_max) {
    output_min = output_min - 1;
    output_max = output_max + 1;
  }

  PADDLE_ENFORCE_EQ(
      (std::isinf(static_cast<float>(output_min)) ||
       std::isnan(static_cast<float>(output_min)) ||
       std::isinf(static_cast<float>(output_max)) ||
       std::isnan(static_cast<float>(output_max))),
      false, pten::errors::OutOfRange("range of min, max is not finite"));
  PADDLE_ENFORCE_GE(
      output_max, output_min,
      pten::errors::InvalidArgument(
          "max must be larger or equal to min. If min and max are both zero, "
          "the minimum and maximum values of the data are used. "
          "But received max is %d, min is %d",
          maxval, minval));

  auto stream = dev_ctx.stream();
  // Dynamic shared memory: one int64_t counter per bin for each block's local
  // histogram.
  KernelHistogram<
      T, IndexType><<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS,
                      nbins * sizeof(int64_t), stream>>>(
      input_data, input_numel, nbins, output_min, output_max, out_data);
}

}  // namespace pten


PT_REGISTER_KERNEL(histogram,
                   GPU,
                   ALL_LAYOUT,
                   pten::HistogramKernel,
                   float,
                   double,
                   int,
                   int64_t) {}