#include "paddle/pten/kernels/histogram_kernel.h"

#include <algorithm>
#include <cmath>

#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/math_function.h"

namespace pten {

/**
 * CPU histogram kernel.
 *
 * Counts the elements of `input` that fall into `bins` equal-width buckets
 * spanning [min, max] and writes the per-bucket counts (int64_t) to `output`.
 *
 * @param dev_ctx  Device context; used to allocate and zero-fill `output`.
 * @param input    Values to bucket, read as a flat array of `input.numel()`.
 * @param bins     Number of buckets; the counts are written to indices
 *                 [0, bins) of `output`.
 * @param min      Lower bound of the range. If it equals `max`, the range is
 *                 taken from the minimum and maximum of the data instead.
 * @param max      Upper bound of the range (inclusive).
 * @param output   Receives the counts; it is zero-filled before counting.
 *
 * Elements outside [min, max] are not counted (NaN elements also fail both
 * comparisons and are skipped). Raises OutOfRange if a resolved bound is
 * inf/NaN, and InvalidArgument if the resolved max is below the resolved min.
 */
template <typename T, typename Context>
void HistogramKernel(const Context& dev_ctx,
                     const DenseTensor& input,
                     int64_t bins,
                     int min,
                     int max,
                     DenseTensor* output) {
  auto& nbins = bins;
  auto& minval = min;
  auto& maxval = max;

  const T* input_data = input.data<T>();
  auto input_numel = input.numel();

  int64_t* out_data = output->mutable_data<int64_t>(dev_ctx.GetPlace());
  pten::funcs::SetConstant<Context, int64_t>()(
      dev_ctx, output, static_cast<int64_t>(0));

  // Nothing to count: the output stays all zeros. The numel check also keeps
  // the min/max_element dereference below off an empty range (UB otherwise).
  if (input_data == nullptr || input_numel == 0) return;

  T output_min = static_cast<T>(minval);
  T output_max = static_cast<T>(maxval);
  // min == max (e.g. both left at 0) means "use the data's own range".
  if (output_min == output_max) {
    output_min = *std::min_element(input_data, input_data + input_numel);
    output_max = *std::max_element(input_data, input_data + input_numel);
  }
  // All elements equal: widen to a non-empty range so the bin-width
  // division below is never by zero.
  if (output_min == output_max) {
    output_min = output_min - 1;
    output_max = output_max + 1;
  }

  // Check BOTH bounds for inf and NaN. Casting to double (not float) keeps
  // large-but-finite double bounds from being misreported as infinite.
  const double lower = static_cast<double>(output_min);
  const double upper = static_cast<double>(output_max);
  PADDLE_ENFORCE_EQ(
      (std::isinf(lower) || std::isnan(lower) || std::isinf(upper) ||
       std::isnan(upper)),
      false,
      pten::errors::OutOfRange("range of min, max is not finite"));
  PADDLE_ENFORCE_GE(
      output_max,
      output_min,
      pten::errors::InvalidArgument(
          "max must be larger or equal to min. If min and max are both zero, "
          "the minimum and maximum values of the data are used. "
          "But received max is %d, min is %d",
          maxval,
          minval));

  // NOTE(review): for integral T these differences are computed in T and can
  // overflow for extreme ranges (e.g. near INT_MIN/INT_MAX). The loop also
  // assumes nbins >= 1 (nbins <= 0 would index out of bounds); presumably
  // validated at shape inference — confirm.
  const T range = output_max - output_min;
  for (int64_t i = 0; i < input_numel; i++) {
    const T value = input_data[i];
    if (value >= output_min && value <= output_max) {
      const int64_t bin =
          static_cast<int64_t>((value - output_min) * nbins / range);
      // Values at (or rounding up to) the upper bound yield index nbins;
      // fold them into the last bucket.
      out_data[std::min(bin, nbins - 1)] += 1;
    }
  }
}

}  // namespace pten

PT_REGISTER_KERNEL(histogram,
                   CPU,
                   ALL_LAYOUT,
                   pten::HistogramKernel,
                   float,
                   double,
                   int,
                   int64_t) {}