histogram_kernel.cc 2.4 KB
Newer Older
P
phlrain 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

#include "paddle/pten/kernels/histogram_kernel.h"

#include <algorithm>
#include <cmath>

#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/math_function.h"


namespace pten {

/**
 * @brief Counts how many elements of `input` fall into each of `bins`
 *        equal-width buckets spanning a closed value range.
 *
 * The range is taken from `min`/`max`. When `min == max` the smallest and
 * largest elements of `input` are used instead, and if those are also equal
 * the range is widened by one on each side so that its width is non-zero.
 * Elements outside the range (and NaN elements, which fail both comparisons)
 * are not counted. An element equal to the upper bound is counted in the
 * last bucket.
 *
 * @tparam T        Element type read from `input`.
 * @tparam Context  Device context type; only `GetPlace()` is used here, and
 *                  the context is forwarded to `funcs::SetConstant`.
 * @param dev_ctx   Device context.
 * @param input     Tensor whose `numel()` elements of type `T` are binned.
 * @param bins      Number of buckets. NOTE(review): assumed to be >= 1 and
 *                  `output` is assumed to hold at least `bins` elements; this
 *                  is not validated here -- confirm it is enforced upstream.
 * @param min       Lower bound of the range.
 * @param max       Upper bound of the range.
 * @param output    Receives the `int64_t` counts; set to zero (through
 *                  `funcs::SetConstant`) before counting starts.
 *
 * Raises `OutOfRange` through `PADDLE_ENFORCE_EQ` when either effective bound
 * is infinite or NaN, and `InvalidArgument` through `PADDLE_ENFORCE_GE` when
 * the effective upper bound is below the lower bound.
 */
template <typename T, typename Context>
void HistogramKernel(const Context& dev_ctx,
                     const DenseTensor& input,
                     int64_t bins,
                     int min,
                     int max,
                     DenseTensor* output) {
  const T* input_data = input.data<T>();
  const auto input_numel = input.numel();

  int64_t* out_data = output->mutable_data<int64_t>(dev_ctx.GetPlace());
  pten::funcs::SetConstant<Context, int64_t>()(
      dev_ctx, output, static_cast<int64_t>(0));

  // Nothing to count. The element-count test also keeps the
  // min_element/max_element calls below from dereferencing the end iterator
  // of an empty range.
  if (input_data == nullptr || input_numel == 0) return;

  T output_min = static_cast<T>(min);
  T output_max = static_cast<T>(max);
  if (output_min == output_max) {
    // No explicit range was given: fall back to the extent of the data.
    output_min = *std::min_element(input_data, input_data + input_numel);
    output_max = *std::max_element(input_data, input_data + input_numel);
  }
  if (output_min == output_max) {
    // Every element is identical; widen the range so its width is non-zero.
    output_min = output_min - 1;
    output_max = output_max + 1;
  }

  // Both bounds must be finite. The test is done in double so that bounds of
  // a double tensor that exceed the float range are not misreported.
  const bool range_is_finite =
      std::isfinite(static_cast<double>(output_min)) &&
      std::isfinite(static_cast<double>(output_max));
  PADDLE_ENFORCE_EQ(
      range_is_finite, true,
      pten::errors::OutOfRange("range of min, max is not finite"));
  PADDLE_ENFORCE_GE(
      output_max, output_min,
      pten::errors::InvalidArgument(
          "max must be larger or equal to min. If min and max are both zero, "
          "the minimum and maximum values of the data are used. "
          "But received max is %d, min is %d",
          max, min));

  const T range_width = output_max - output_min;
  for (int64_t i = 0; i < input_numel; ++i) {
    const T value = input_data[i];
    // Written as a positive range test so that NaN elements are skipped.
    if (value >= output_min && value <= output_max) {
      const int64_t bin =
          static_cast<int64_t>((value - output_min) * bins / range_width);
      // value == output_max yields bin == bins; fold it into the last bucket.
      out_data[std::min(bin, bins - 1)] += 1;
    }
  }
}

}  // namespace pten


// Invokes the kernel-registration macro for pten::HistogramKernel under the
// name `histogram`, with backend CPU, layout ALL_LAYOUT, and the element
// types float, double, int and int64_t. The trailing `{}` body is left empty.
// NOTE(review): the macro is defined outside this file; presumably it
// instantiates the kernel template once per listed type -- confirm in
// kernel_registry.h.
PT_REGISTER_KERNEL(histogram,
                   CPU,
                   ALL_LAYOUT,
                   pten::HistogramKernel,
                   float,
                   double,
                   int,
                   int64_t) {}