From 704cad6a66be50e7a81f418650805a1051d6a62c Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Tue, 2 Jun 2020 19:55:56 +0800
Subject: [PATCH] Add histc op (#24562)

* add histc operator, test=develop
* update english doc to 2.0 API, test=develop
* update API from histc to histogram, test=develop

Co-authored-by: root
---
 paddle/fluid/operators/histogram_op.cc      |  92 +++++++++++
 paddle/fluid/operators/histogram_op.cu      | 147 ++++++++++++++++++
 paddle/fluid/operators/histogram_op.h       |  82 ++++++++++
 python/paddle/__init__.py                   |   1 +
 .../tests/unittests/test_histogram_op.py    |  87 +++++++++++
 python/paddle/tensor/__init__.py            |   1 +
 python/paddle/tensor/linalg.py              |  63 +++++++-
 7 files changed, 472 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/histogram_op.cc
 create mode 100644 paddle/fluid/operators/histogram_op.cu
 create mode 100644 paddle/fluid/operators/histogram_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_histogram_op.py

diff --git a/paddle/fluid/operators/histogram_op.cc b/paddle/fluid/operators/histogram_op.cc
new file mode 100644
index 00000000000..32cc38ef195
--- /dev/null
+++ b/paddle/fluid/operators/histogram_op.cc
@@ -0,0 +1,92 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/histogram_op.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+using framework::OpKernelType;
+using framework::Tensor;
+
+class HistogramOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "histogram");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "histogram");
+    const auto &nbins = ctx->Attrs().Get<int>("bins");
+    const auto &minval = ctx->Attrs().Get<int>("min");
+    const auto &maxval = ctx->Attrs().Get<int>("max");
+
+    PADDLE_ENFORCE_GE(nbins, 1,
+                      platform::errors::InvalidArgument(
+                          "The bins should be greater than or equal to 1. "
+                          "But received nbins is %d.",
+                          nbins));
+    PADDLE_ENFORCE_GE(maxval, minval,
+                      platform::errors::InvalidArgument(
+                          "max must be greater than or equal to min. "
+                          "But received max is %d, min is %d.",
+                          maxval, minval));
+
+    ctx->SetOutputDim("Out", framework::make_ddim({nbins}));
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+class HistogramOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) The input tensor of Histogram op.");
+    AddOutput("Out", "(Tensor) The output tensor of Histogram op.");
+    AddAttr<int>("bins", "(int) number of histogram bins")
+        .SetDefault(100)
+        .EqualGreaterThan(1);
+    AddAttr<int>("min", "(int) lower end of the range (inclusive)")
+        .SetDefault(0);
+    AddAttr<int>("max", "(int) upper end of the range (inclusive)")
+        .SetDefault(0);
+    AddComment(R"DOC(
+          Histogram Operator.
+          Computes the histogram of a tensor. The elements are sorted
+          into equal-width bins between min and max. If min and max are
+          both zero, the minimum and maximum values of the data are used.
+    )DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    histogram, ops::HistogramOp, ops::HistogramOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(
+    histogram, ops::HistogramKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::HistogramKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::HistogramKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::HistogramKernel<paddle::platform::CPUDeviceContext, int64_t>);
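A rough Python sketch of the binning rule described in the op comment above, for intuition only (histogram_reference is an illustrative name, not part of the patch): equal-width bins over the inclusive range [min, max], a fallback to the data's min/max when both are zero, and a range widened by one on each side when the two ends still coincide.

    import numpy as np

    def histogram_reference(x, bins=100, min_val=0, max_val=0):
        # Fall back to the data range when min and max are both zero.
        lo, hi = (min_val, max_val) if (min_val != 0 or max_val != 0) \
            else (x.min(), x.max())
        if lo == hi:  # degenerate range: widen by one on each side
            lo, hi = lo - 1, hi + 1
        out = np.zeros(bins, dtype=np.int64)
        for v in x.ravel():
            if lo <= v <= hi:  # both ends inclusive; out-of-range values dropped
                b = int((v - lo) * bins / (hi - lo))
                out[min(b, bins - 1)] += 1  # fold v == hi into the last bin
        return out

    # Matches the unit test below: [[2, 4, 2], [2, 5, 4]] with
    # bins=5, min=1, max=5 gives [0, 3, 0, 2, 1].
    print(histogram_reference(np.array([[2, 4, 2], [2, 5, 4]]), 5, 1, 5))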
+ "But received max is %d, min is %d", + maxval, minval)); + + ctx->SetOutputDim("Out", framework::make_ddim({nbins})); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class HistogramOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) The input tensor of Histogram op,"); + AddOutput("Out", "(Tensor) The output tensor of Histogram op,"); + AddAttr("bins", "(int) number of histogram bins") + .SetDefault(100) + .EqualGreaterThan(1); + AddAttr("min", "(int) lower end of the range (inclusive)") + .SetDefault(0); + AddAttr("max", "(int) upper end of the range (inclusive)") + .SetDefault(0); + AddComment(R"DOC( + Histogram Operator. + Computes the histogram of a tensor. The elements are sorted + into equal width bins between min and max. If min and max are + both zero, the minimum and maximum values of the data are used. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + histogram, ops::HistogramOp, ops::HistogramOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + histogram, ops::HistogramKernel, + ops::HistogramKernel, + ops::HistogramKernel, + ops::HistogramKernel); diff --git a/paddle/fluid/operators/histogram_op.cu b/paddle/fluid/operators/histogram_op.cu new file mode 100644 index 00000000000..359e90bfc3a --- /dev/null +++ b/paddle/fluid/operators/histogram_op.cu @@ -0,0 +1,147 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+template <typename DeviceContext, typename T>
+class HistogramCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(context.GetPlace()), true,
+        platform::errors::InvalidArgument("It must use CUDAPlace."));
+
+    const Tensor* input = context.Input<Tensor>("X");
+    Tensor* output = context.Output<Tensor>("Out");
+    auto& nbins = context.Attr<int>("bins");
+    auto& minval = context.Attr<int>("min");
+    auto& maxval = context.Attr<int>("max");
+
+    const T* input_data = input->data<T>();
+    const int input_numel = input->numel();
+
+    T output_min = static_cast<T>(minval);
+    T output_max = static_cast<T>(maxval);
+
+    if (output_min == output_max) {
+      // min == max == 0: derive the range from the data with a device-side
+      // Eigen reduction, then copy the two scalars back to the host.
+      auto input_x = framework::EigenVector<T>::Flatten(*input);
+
+      framework::Tensor input_min_t, input_max_t;
+      auto* input_min_data =
+          input_min_t.mutable_data<T>({1}, context.GetPlace());
+      auto* input_max_data =
+          input_max_t.mutable_data<T>({1}, context.GetPlace());
+      auto input_min_scala = framework::EigenScalar<T>::From(input_min_t);
+      auto input_max_scala = framework::EigenScalar<T>::From(input_max_t);
+
+      auto* place =
+          context.template device_context<DeviceContext>().eigen_device();
+      input_min_scala.device(*place) = input_x.minimum();
+      input_max_scala.device(*place) = input_x.maximum();
+
+      Tensor input_min_cpu, input_max_cpu;
+      TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);
+      TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);
+
+      output_min = input_min_cpu.data<T>()[0];
+      output_max = input_max_cpu.data<T>()[0];
+    }
+    if (output_min == output_max) {
+      output_min = output_min - 1;
+      output_max = output_max + 1;
+    }
+
+    PADDLE_ENFORCE_EQ(
+        (std::isinf(static_cast<double>(output_min)) ||
+         std::isnan(static_cast<double>(output_min)) ||
+         std::isinf(static_cast<double>(output_max)) ||
+         std::isnan(static_cast<double>(output_max))),
+        false, platform::errors::OutOfRange("range of min, max is not finite"));
+    PADDLE_ENFORCE_GE(
+        output_max, output_min,
+        platform::errors::InvalidArgument(
+            "max must be greater than or equal to min. If min and max are "
+            "both zero, the minimum and maximum values of the data are "
+            "used. But received max is %d, min is %d.",
+            maxval, minval));
+
+    int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
+    math::SetConstant<platform::CUDADeviceContext, int64_t>()(
+        context.template device_context<platform::CUDADeviceContext>(), output,
+        static_cast<int64_t>(0));
+
+    auto stream =
+        context.template device_context<platform::CUDADeviceContext>()
+            .stream();
+    KernelHistogram<T, IndexType><<<GET_BLOCKS(input_numel),
+                                    PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+        input_data, input_numel, nbins, output_min, output_max, out_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    histogram,
+    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
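The launch configuration above gives the kernel roughly one thread per input element, with the grid-stride CUDA_KERNEL_LOOP absorbing any remainder. The ceiling division in GET_BLOCKS is easy to sanity-check in Python (the thread count of 512 is an assumption here; the real constant comes from cuda_primitives.h):

    PADDLE_CUDA_NUM_THREADS = 512  # assumed value for illustration

    def get_blocks(n, threads=PADDLE_CUDA_NUM_THREADS):
        # Ceiling division: smallest block count with blocks * threads >= n.
        return (n + threads - 1) // threads

    assert get_blocks(1) == 1
    assert get_blocks(512) == 1
    assert get_blocks(513) == 2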
diff --git a/paddle/fluid/operators/histogram_op.h b/paddle/fluid/operators/histogram_op.h
new file mode 100644
index 00000000000..6e48c86d022
--- /dev/null
+++ b/paddle/fluid/operators/histogram_op.h
@@ -0,0 +1,82 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class HistogramKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("X");
+    Tensor* output = context.Output<Tensor>("Out");
+    auto& nbins = context.Attr<int>("bins");
+    auto& minval = context.Attr<int>("min");
+    auto& maxval = context.Attr<int>("max");
+
+    const T* input_data = input->data<T>();
+    auto input_numel = input->numel();
+
+    T output_min = static_cast<T>(minval);
+    T output_max = static_cast<T>(maxval);
+    if (output_min == output_max) {
+      // min == max == 0: derive the range from the data.
+      output_min = *std::min_element(input_data, input_data + input_numel);
+      output_max = *std::max_element(input_data, input_data + input_numel);
+    }
+    if (output_min == output_max) {
+      output_min = output_min - 1;
+      output_max = output_max + 1;
+    }
+
+    PADDLE_ENFORCE_EQ(
+        (std::isinf(static_cast<double>(output_min)) ||
+         std::isnan(static_cast<double>(output_min)) ||
+         std::isinf(static_cast<double>(output_max)) ||
+         std::isnan(static_cast<double>(output_max))),
+        false, platform::errors::OutOfRange("range of min, max is not finite"));
+    PADDLE_ENFORCE_GE(
+        output_max, output_min,
+        platform::errors::InvalidArgument(
+            "max must be greater than or equal to min. If min and max are "
+            "both zero, the minimum and maximum values of the data are "
+            "used. But received max is %d, min is %d.",
+            maxval, minval));
+
+    int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
+    math::SetConstant<DeviceContext, int64_t>()(
+        context.template device_context<DeviceContext>(), output,
+        static_cast<int64_t>(0));
+
+    for (int64_t i = 0; i < input_numel; i++) {
+      if (input_data[i] >= output_min && input_data[i] <= output_max) {
+        const int64_t bin = static_cast<int64_t>(
+            (input_data[i] - output_min) * nbins / (output_max - output_min));
+        out_data[std::min(bin, static_cast<int64_t>(nbins) - 1)] += 1;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
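The finiteness check above matters mostly when the range is derived from the data: a single NaN in a float input propagates through the min/max reduction and would poison the bin width, so both kernels reject it up front. A quick NumPy illustration of the failure mode being guarded against (illustrative, not part of the patch):

    import numpy as np

    x = np.array([1.0, 2.0, np.nan], dtype=np.float32)
    lo, hi = x.min(), x.max()          # NaN propagates into both bounds
    print(np.isnan(lo), np.isnan(hi))  # True True -> the op raises OutOfRange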
" + "But received max is %d, min is %d", + maxval, minval)); + + int64_t* out_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + + for (int64_t i = 0; i < input_numel; i++) { + if (input_data[i] >= output_min && input_data[i] <= output_max) { + const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / + (output_max - output_min)); + out_data[std::min(bin, nbins - 1)] += 1; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 83f1d0439e1..9862ef8ac06 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -81,6 +81,7 @@ from .tensor.linalg import cross #DEFINE_ALIAS from .tensor.linalg import cholesky #DEFINE_ALIAS # from .tensor.linalg import tensordot #DEFINE_ALIAS from .tensor.linalg import bmm #DEFINE_ALIAS +from .tensor.linalg import histogram #DEFINE_ALIAS from .tensor.logic import equal #DEFINE_ALIAS from .tensor.logic import greater_equal #DEFINE_ALIAS from .tensor.logic import greater_than #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_histogram_op.py b/python/paddle/fluid/tests/unittests/test_histogram_op.py new file mode 100644 index 00000000000..0f880f2b035 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_histogram_op.py @@ -0,0 +1,87 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 2c310115449..a96d112c8ea 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -55,6 +55,7 @@ from .linalg import cross  #DEFINE_ALIAS
 from .linalg import cholesky  #DEFINE_ALIAS
 # from .linalg import tensordot  #DEFINE_ALIAS
 from .linalg import bmm  #DEFINE_ALIAS
+from .linalg import histogram  #DEFINE_ALIAS
 from .logic import equal  #DEFINE_ALIAS
 from .logic import greater_equal  #DEFINE_ALIAS
 from .logic import greater_than  #DEFINE_ALIAS
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 85506f1b7be..f0e1c78f117 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -30,7 +30,8 @@ __all__ = [
     'cross',
     'cholesky',
     # 'tensordot',
-    'bmm'
+    'bmm',
+    'histogram'
 ]
The data type of the input Tensor + should be float32, float64, int32, int64. + bins (int): number of histogram bins + min (int): lower end of the range (inclusive) + max (int): upper end of the range (inclusive) + + Returns: + Variable: Tensor or LoDTensor calculated by histogram layer. The data type is int64. + + Code Example 1: + .. code-block:: python + import paddle + import numpy as np + startup_program = paddle.Program() + train_program = paddle.Program() + with paddle.program_guard(train_program, startup_program): + inputs = paddle.data(name='input', dtype='int32', shape=[2,3]) + output = paddle.histogram(inputs, bins=5, min=1, max=5) + place = paddle.CPUPlace() + exe = paddle.Executor(place) + exe.run(startup_program) + img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int32) + res = exe.run(train_program, + feed={'input': img}, + fetch_list=[output]) + print(np.array(res[0])) # [0,3,0,2,1] + + Code Example 2: + .. code-block:: python + import paddle + import numpy as np + with paddle.imperative.guard(paddle.CPUPlace()): + inputs_np = np.array([1, 2, 1]).astype(np.float) + inputs = paddle.imperative.to_variable(inputs_np) + result = paddle.histogram(inputs, bins=4, min=0, max=3) + print(result) # [0, 2, 1, 0] + """ + if in_dygraph_mode(): + return core.ops.histogram(input, "bins", bins, "min", min, "max", max) + + helper = LayerHelper('histogram', **locals()) + check_variable_and_dtype( + input, 'X', ['int32', 'int64', 'float32', 'float64'], 'histogram') + out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64) + helper.append_op( + type='histogram', + inputs={'X': input}, + outputs={'Out': out}, + attrs={'bins': bins, + 'min': min, + 'max': max}) + return out -- GitLab