未验证 提交 704cad6a 编写于 作者: Q Qi Li 提交者: GitHub

Add histc op (#24562)

* add histc operator, test=develop

* update english doc to 2.0 API, test=develop

* update API from histc to histogram, test=develop
Co-authored-by: Nroot <root@yq01-gpu-255-129-15-00.epc.baidu.com>
上级 1f032c53
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/histogram_op.h"
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class HistogramOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "histogram");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "histogram");
const auto &nbins = ctx->Attrs().Get<int64_t>("bins");
const auto &minval = ctx->Attrs().Get<int>("min");
const auto &maxval = ctx->Attrs().Get<int>("max");
PADDLE_ENFORCE_GE(nbins, 1,
platform::errors::InvalidArgument(
"The bins should be greater than or equal to 1."
"But received nbins is %d",
nbins));
PADDLE_ENFORCE_GE(maxval, minval, platform::errors::InvalidArgument(
"max must be larger or equal to min."
"But received max is %d, min is %d",
maxval, minval));
ctx->SetOutputDim("Out", framework::make_ddim({nbins}));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class HistogramOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input tensor of Histogram op,");
AddOutput("Out", "(Tensor) The output tensor of Histogram op,");
AddAttr<int64_t>("bins", "(int) number of histogram bins")
.SetDefault(100)
.EqualGreaterThan(1);
AddAttr<int>("min", "(int) lower end of the range (inclusive)")
.SetDefault(0);
AddAttr<int>("max", "(int) upper end of the range (inclusive)")
.SetDefault(0);
AddComment(R"DOC(
Histogram Operator.
Computes the histogram of a tensor. The elements are sorted
into equal width bins between min and max. If min and max are
both zero, the minimum and maximum values of the data are used.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
histogram, ops::HistogramOp, ops::HistogramOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
histogram, ops::HistogramKernel<paddle::platform::CPUDeviceContext, float>,
ops::HistogramKernel<paddle::platform::CPUDeviceContext, double>,
ops::HistogramKernel<paddle::platform::CPUDeviceContext, int>,
ops::HistogramKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/histogram_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using IndexType = int64_t;
using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}
template <typename T, typename IndexType>
__device__ static IndexType GetBin(T bVal, T minvalue, T maxvalue,
int64_t nbins) {
IndexType bin =
static_cast<int>((bVal - minvalue) * nbins / (maxvalue - minvalue));
if (bin == nbins) bin -= 1;
return bin;
}
template <typename T, typename IndexType>
__global__ void KernelHistogram(const T* input, const int totalElements,
const int64_t nbins, const T minvalue,
const T maxvalue, int64_t* output) {
CUDA_KERNEL_LOOP(linearIndex, totalElements) {
const IndexType inputIdx = threadIdx.x + blockIdx.x * blockDim.x;
const auto inputVal = input[inputIdx];
if (inputVal >= minvalue && inputVal <= maxvalue) {
const IndexType bin =
GetBin<T, IndexType>(inputVal, minvalue, maxvalue, nbins);
const IndexType outputIdx = bin < nbins - 1 ? bin : nbins - 1;
paddle::platform::CudaAtomicAdd(&output[outputIdx], 1);
}
}
}
template <typename DeviceContext, typename T>
class HistogramCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
const Tensor* input = context.Input<framework::Tensor>("X");
Tensor* output = context.Output<framework::Tensor>("Out");
auto& nbins = context.Attr<int64_t>("bins");
auto& minval = context.Attr<int>("min");
auto& maxval = context.Attr<int>("max");
const T* input_data = input->data<T>();
const int input_numel = input->numel();
T output_min = static_cast<T>(minval);
T output_max = static_cast<T>(maxval);
if (output_min == output_max) {
auto input_x = framework::EigenVector<T>::Flatten(*input);
framework::Tensor input_min_t, input_max_t;
auto* input_min_data =
input_min_t.mutable_data<T>({1}, context.GetPlace());
auto* input_max_data =
input_max_t.mutable_data<T>({1}, context.GetPlace());
auto input_min_scala = framework::EigenScalar<T>::From(input_min_t);
auto input_max_scala = framework::EigenScalar<T>::From(input_max_t);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
input_min_scala.device(*place) = input_x.minimum();
input_max_scala.device(*place) = input_x.maximum();
Tensor input_min_cpu, input_max_cpu;
TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);
TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);
output_min = input_min_cpu.data<T>()[0];
output_max = input_max_cpu.data<T>()[0];
}
if (output_min == output_max) {
output_min = output_min - 1;
output_max = output_max + 1;
}
PADDLE_ENFORCE_EQ(
(std::isinf(static_cast<float>(output_min)) ||
std::isnan(static_cast<float>(output_max)) ||
std::isinf(static_cast<float>(output_min)) ||
std::isnan(static_cast<float>(output_max))),
false, platform::errors::OutOfRange("range of min, max is not finite"));
PADDLE_ENFORCE_GE(
output_max, output_min,
platform::errors::InvalidArgument(
"max must be larger or equal to min. If min and max are both zero, "
"the minimum and maximum values of the data are used. "
"But received max is %d, min is %d",
maxval, minval));
int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, int64_t>()(
context.template device_context<platform::CUDADeviceContext>(), output,
static_cast<int64_t>(0));
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
KernelHistogram<T, IndexType><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, nbins, output_min, output_max, out_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
histogram,
ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class HistogramKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<framework::Tensor>("X");
Tensor* output = context.Output<framework::Tensor>("Out");
auto& nbins = context.Attr<int64_t>("bins");
auto& minval = context.Attr<int>("min");
auto& maxval = context.Attr<int>("max");
const T* input_data = input->data<T>();
auto input_numel = input->numel();
T output_min = static_cast<T>(minval);
T output_max = static_cast<T>(maxval);
if (output_min == output_max) {
output_min = *std::min_element(input_data, input_data + input_numel);
output_max = *std::max_element(input_data, input_data + input_numel);
}
if (output_min == output_max) {
output_min = output_min - 1;
output_max = output_max + 1;
}
PADDLE_ENFORCE_EQ(
(std::isinf(static_cast<float>(output_min)) ||
std::isnan(static_cast<float>(output_max)) ||
std::isinf(static_cast<float>(output_min)) ||
std::isnan(static_cast<float>(output_max))),
false, platform::errors::OutOfRange("range of min, max is not finite"));
PADDLE_ENFORCE_GE(
output_max, output_min,
platform::errors::InvalidArgument(
"max must be larger or equal to min. If min and max are both zero, "
"the minimum and maximum values of the data are used. "
"But received max is %d, min is %d",
maxval, minval));
int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<DeviceContext, int64_t>()(
context.template device_context<DeviceContext>(), output,
static_cast<int64_t>(0));
for (int64_t i = 0; i < input_numel; i++) {
if (input_data[i] >= output_min && input_data[i] <= output_max) {
const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins /
(output_max - output_min));
out_data[std::min(bin, nbins - 1)] += 1;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -81,6 +81,7 @@ from .tensor.linalg import cross #DEFINE_ALIAS
from .tensor.linalg import cholesky #DEFINE_ALIAS
# from .tensor.linalg import tensordot #DEFINE_ALIAS
from .tensor.linalg import bmm #DEFINE_ALIAS
from .tensor.linalg import histogram #DEFINE_ALIAS
from .tensor.logic import equal #DEFINE_ALIAS
from .tensor.logic import greater_equal #DEFINE_ALIAS
from .tensor.logic import greater_than #DEFINE_ALIAS
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest
class TestHistogramOpAPI(unittest.TestCase):
"""Test histogram api."""
def test_static_graph(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
inputs = fluid.data(name='input', dtype='int64', shape=[2, 3])
output = paddle.histogram(inputs, bins=5, min=1, max=5)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64)
res = exe.run(train_program,
feed={'input': img},
fetch_list=[output])
actual = np.array(res[0])
expected = np.array([0, 3, 0, 2, 1]).astype(np.int64)
self.assertTrue(
(actual == expected).all(),
msg='histogram output is wrong, out =' + str(actual))
def test_dygraph(self):
with fluid.dygraph.guard():
inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64)
inputs = fluid.dygraph.to_variable(inputs_np)
actual = paddle.histogram(inputs, bins=5, min=1, max=5)
expected = np.array([0, 3, 0, 2, 1]).astype(np.int64)
self.assertTrue(
(actual.numpy() == expected).all(),
msg='histogram output is wrong, out =' + str(actual.numpy()))
class TestHistogramOp(OpTest):
def setUp(self):
self.op_type = "histogram"
self.init_test_case()
np_input = np.random.randint(
low=0, high=20, size=self.in_shape, dtype=np.int64)
self.inputs = {"X": np_input}
self.init_attrs()
Out, _ = np.histogram(
np_input, bins=self.bins, range=(self.min, self.max))
self.outputs = {"Out": Out.astype(np.int64)}
def init_test_case(self):
self.in_shape = (10, 12)
self.bins = 5
self.min = 1
self.max = 5
def init_attrs(self):
self.attrs = {"bins": self.bins, "min": self.min, "max": self.max}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
......@@ -55,6 +55,7 @@ from .linalg import cross #DEFINE_ALIAS
from .linalg import cholesky #DEFINE_ALIAS
# from .linalg import tensordot #DEFINE_ALIAS
from .linalg import bmm #DEFINE_ALIAS
from .linalg import histogram #DEFINE_ALIAS
from .logic import equal #DEFINE_ALIAS
from .logic import greater_equal #DEFINE_ALIAS
from .logic import greater_than #DEFINE_ALIAS
......
......@@ -30,7 +30,8 @@ __all__ = [
'cross',
'cholesky',
# 'tensordot',
'bmm'
'bmm',
'histogram'
]
......@@ -751,3 +752,63 @@ def bmm(x, y, name=None):
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='bmm', inputs={'X': x, 'Y': y}, outputs={'Out': out})
return out
def histogram(input, bins=100, min=0, max=0):
"""
Computes the histogram of a tensor. The elements are sorted into equal width bins between min and max.
If min and max are both zero, the minimum and maximum values of the data are used.
Args:
input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
should be float32, float64, int32, int64.
bins (int): number of histogram bins
min (int): lower end of the range (inclusive)
max (int): upper end of the range (inclusive)
Returns:
Variable: Tensor or LoDTensor calculated by histogram layer. The data type is int64.
Code Example 1:
.. code-block:: python
import paddle
import numpy as np
startup_program = paddle.Program()
train_program = paddle.Program()
with paddle.program_guard(train_program, startup_program):
inputs = paddle.data(name='input', dtype='int32', shape=[2,3])
output = paddle.histogram(inputs, bins=5, min=1, max=5)
place = paddle.CPUPlace()
exe = paddle.Executor(place)
exe.run(startup_program)
img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int32)
res = exe.run(train_program,
feed={'input': img},
fetch_list=[output])
print(np.array(res[0])) # [0,3,0,2,1]
Code Example 2:
.. code-block:: python
import paddle
import numpy as np
with paddle.imperative.guard(paddle.CPUPlace()):
inputs_np = np.array([1, 2, 1]).astype(np.float)
inputs = paddle.imperative.to_variable(inputs_np)
result = paddle.histogram(inputs, bins=4, min=0, max=3)
print(result) # [0, 2, 1, 0]
"""
if in_dygraph_mode():
return core.ops.histogram(input, "bins", bins, "min", min, "max", max)
helper = LayerHelper('histogram', **locals())
check_variable_and_dtype(
input, 'X', ['int32', 'int64', 'float32', 'float64'], 'histogram')
out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op(
type='histogram',
inputs={'X': input},
outputs={'Out': out},
attrs={'bins': bins,
'min': min,
'max': max})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册