histogram_kernel.cc 3.0 KB
Newer Older
P
phlrain 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
P
phlrain 已提交
14 15 16 17

#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
P
phlrain 已提交
18
#include "paddle/pten/kernels/funcs/math_function.h"
P
phlrain 已提交
19 20 21 22 23

namespace pten {

template <typename T, typename Context>
void HistogramKernel(const Context& dev_ctx,
P
phlrain 已提交
24 25 26 27 28 29 30 31
                     const DenseTensor& input,
                     int64_t bins,
                     int min,
                     int max,
                     DenseTensor* output) {
  auto& nbins = bins;
  auto& minval = min;
  auto& maxval = max;
P
phlrain 已提交
32

P
phlrain 已提交
33 34
  const T* input_data = input.data<T>();
  auto input_numel = input.numel();
P
phlrain 已提交
35

P
phlrain 已提交
36 37 38
  int64_t* out_data = output->mutable_data<int64_t>(dev_ctx.GetPlace());
  pten::funcs::SetConstant<Context, int64_t>()(
      dev_ctx, output, static_cast<int64_t>(0));
P
phlrain 已提交
39

P
phlrain 已提交
40
  if (input_data == nullptr) return;
P
phlrain 已提交
41

P
phlrain 已提交
42 43 44 45 46 47 48 49 50 51
  T output_min = static_cast<T>(minval);
  T output_max = static_cast<T>(maxval);
  if (output_min == output_max) {
    output_min = *std::min_element(input_data, input_data + input_numel);
    output_max = *std::max_element(input_data, input_data + input_numel);
  }
  if (output_min == output_max) {
    output_min = output_min - 1;
    output_max = output_max + 1;
  }
P
phlrain 已提交
52

P
phlrain 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
  PADDLE_ENFORCE_EQ(
      (std::isinf(static_cast<float>(output_min)) ||
       std::isnan(static_cast<float>(output_max)) ||
       std::isinf(static_cast<float>(output_min)) ||
       std::isnan(static_cast<float>(output_max))),
      false,
      pten::errors::OutOfRange("range of min, max is not finite"));
  PADDLE_ENFORCE_GE(
      output_max,
      output_min,
      pten::errors::InvalidArgument(
          "max must be larger or equal to min. If min and max are both zero, "
          "the minimum and maximum values of the data are used. "
          "But received max is %d, min is %d",
          maxval,
          minval));
P
phlrain 已提交
69

P
phlrain 已提交
70 71 72 73 74
  for (int64_t i = 0; i < input_numel; i++) {
    if (input_data[i] >= output_min && input_data[i] <= output_max) {
      const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins /
                                    (output_max - output_min));
      out_data[std::min(bin, nbins - 1)] += 1;
P
phlrain 已提交
75
    }
P
phlrain 已提交
76
  }
P
phlrain 已提交
77 78
}

P
phlrain 已提交
79
}  // namespace pten
P
phlrain 已提交
80 81 82 83 84 85 86 87

// Register pten::HistogramKernel as the CPU `histogram` kernel for any layout,
// instantiated for float, double, int and int64_t inputs.
PT_REGISTER_KERNEL(
    histogram, CPU, ALL_LAYOUT, pten::HistogramKernel,
    float, double, int, int64_t) {}