histogram_op.cu 5.9 KB
Newer Older
Q
Qi Li 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/histogram_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using IndexType = int64_t;
using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;

inline int GET_BLOCKS(const int N) {
  return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}

template <typename T, typename IndexType>
33
__device__ static IndexType GetBin(T input_value, T min_value, T max_value,
Q
Qi Li 已提交
34
                                   int64_t nbins) {
35 36 37 38
  IndexType bin = static_cast<int>((input_value - min_value) * nbins /
                                   (max_value - min_value));
  IndexType output_index = bin < nbins - 1 ? bin : nbins - 1;
  return output_index;
Q
Qi Li 已提交
39 40 41
}

template <typename T, typename IndexType>
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
__global__ void KernelHistogram(const T* input, const int total_elements,
                                const int64_t nbins, const T min_value,
                                const T max_value, int64_t* output) {
  extern __shared__ int64_t buf_hist[];
  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
    buf_hist[i] = 0;
  }
  __syncthreads();

  CUDA_KERNEL_LOOP(input_index, total_elements) {
    // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
    const auto input_value = input[input_index];
    if (input_value >= min_value && input_value <= max_value) {
      const IndexType output_index =
          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
      paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1);
Q
Qi Li 已提交
58 59
    }
  }
60 61 62 63 64
  __syncthreads();

  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
    paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
  }
Q
Qi Li 已提交
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
}

template <typename DeviceContext, typename T>
class HistogramCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));

    const Tensor* input = context.Input<framework::Tensor>("X");
    Tensor* output = context.Output<framework::Tensor>("Out");
    auto& nbins = context.Attr<int64_t>("bins");
    auto& minval = context.Attr<int>("min");
    auto& maxval = context.Attr<int>("max");

    const T* input_data = input->data<T>();
    const int input_numel = input->numel();

84 85 86 87 88 89 90
    int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
    math::SetConstant<platform::CUDADeviceContext, int64_t>()(
        context.template device_context<platform::CUDADeviceContext>(), output,
        static_cast<int64_t>(0));

    if (input_data == nullptr) return;

Q
Qi Li 已提交
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
    T output_min = static_cast<T>(minval);
    T output_max = static_cast<T>(maxval);

    if (output_min == output_max) {
      auto input_x = framework::EigenVector<T>::Flatten(*input);

      framework::Tensor input_min_t, input_max_t;
      auto* input_min_data =
          input_min_t.mutable_data<T>({1}, context.GetPlace());
      auto* input_max_data =
          input_max_t.mutable_data<T>({1}, context.GetPlace());
      auto input_min_scala = framework::EigenScalar<T>::From(input_min_t);
      auto input_max_scala = framework::EigenScalar<T>::From(input_max_t);

      auto* place =
          context.template device_context<DeviceContext>().eigen_device();
      input_min_scala.device(*place) = input_x.minimum();
      input_max_scala.device(*place) = input_x.maximum();

      Tensor input_min_cpu, input_max_cpu;
      TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);
      TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);

      output_min = input_min_cpu.data<T>()[0];
      output_max = input_max_cpu.data<T>()[0];
    }
    if (output_min == output_max) {
      output_min = output_min - 1;
      output_max = output_max + 1;
    }

    PADDLE_ENFORCE_EQ(
        (std::isinf(static_cast<float>(output_min)) ||
         std::isnan(static_cast<float>(output_max)) ||
         std::isinf(static_cast<float>(output_min)) ||
         std::isnan(static_cast<float>(output_max))),
        false, platform::errors::OutOfRange("range of min, max is not finite"));
    PADDLE_ENFORCE_GE(
        output_max, output_min,
        platform::errors::InvalidArgument(
            "max must be larger or equal to min. If min and max are both zero, "
            "the minimum and maximum values of the data are used. "
            "But received max is %d, min is %d",
            maxval, minval));

    auto stream =
        context.template device_context<platform::CUDADeviceContext>().stream();
138 139 140
    KernelHistogram<
        T, IndexType><<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS,
                        nbins * sizeof(int64_t), stream>>>(
Q
Qi Li 已提交
141 142 143 144 145 146 147 148 149 150 151 152 153 154
        input_data, input_numel, nbins, output_min, output_max, out_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    histogram,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, double>);