未验证 提交 610a810c 编写于 作者: S smallv0221 提交者: GitHub

Add bincount op (#36317) (#36709)

* Add bincount op

* upload cpu version

* fix unitest

* fix unittest

* fix unittest

* fix en doc

* add more test

* fix en doc

* add more test case

* fix test

* fix input vailidation

* fix input check

* fix unittest

* fix test

* fix en doc

cherry-pick
上级 616ce203
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bincount_op.h"
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class BincountOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of BincountOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of BincountOp should not be null."));
auto input_dim = ctx->GetInputDim("X");
auto minlength = ctx->Attrs().Get<int>("minlength");
PADDLE_ENFORCE_GE(minlength, 0,
platform::errors::InvalidArgument(
"The minlength should be greater than or equal to 0."
"But received minlength is %d",
minlength));
PADDLE_ENFORCE_EQ(input_dim.size(), 1,
platform::errors::InvalidArgument(
"The 'shape' of Input(X) must be 1-D tensor."
"But the dimension of Input(X) is [%d]",
input_dim.size()));
if (ctx->HasInput("Weights")) {
auto weights_dim = ctx->GetInputDim("Weights");
PADDLE_ENFORCE_EQ(weights_dim.size(), 1,
platform::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be 1-D tensor."
"But the dimension of Input(Weights) is [%d]",
weights_dim.size()));
PADDLE_ENFORCE_EQ(
weights_dim[0], input_dim[0],
platform::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
"Input(X)."
"But received: the 'shape' of Input(Weights) is [%s],"
"the 'shape' of Input(X) is [%s]",
weights_dim, input_dim));
}
ctx->SetOutputDim("Out", framework::make_ddim({-1}));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type =
ctx.HasInput("Weights")
? OperatorWithKernel::IndicateVarDataType(ctx, "Weights")
: OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class BincountOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input tensor of Bincount op,");
AddInput("Weights", "(Tensor) The weights tensor of Bincount op,")
.AsDispensable();
AddOutput("Out", "(Tensor) The output tensor of Bincount op,");
AddAttr<int>("minlength", "(int) The minimal numbers of bins")
.SetDefault(0)
.EqualGreaterThan(0);
AddComment(R"DOC(
Bincount Operator.
Computes frequency of each value in the input tensor.
Elements of input tensor should be non-negative ints.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
bincount, ops::BincountOp, ops::BincountOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
bincount, ops::BincountKernel<paddle::platform::CPUDeviceContext, float>,
ops::BincountKernel<paddle::platform::CPUDeviceContext, double>,
ops::BincountKernel<paddle::platform::CPUDeviceContext, int>,
ops::BincountKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/bincount_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}
template <typename T, typename InputT, typename OutT>
__global__ void KernelBincount(const InputT* input, const int total_elements,
const bool has_weights, const T* weights,
OutT* output) {
if (!has_weights) {
for (int i = threadIdx.x; i < total_elements; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&output[input[i]], 1L);
}
} else {
for (int i = threadIdx.x; i < total_elements; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&output[input[i]],
static_cast<OutT>(weights[i]));
}
}
}
template <typename DeviceContext, typename T, typename InputT>
void BincountCUDAInner(const framework::ExecutionContext& context) {
const Tensor* input = context.Input<framework::Tensor>("X");
const Tensor* weights = context.Input<framework::Tensor>("Weights");
Tensor* output = context.Output<framework::Tensor>("Out");
auto& minlength = context.Attr<int>("minlength");
const InputT* input_data = input->data<InputT>();
const int input_numel = input->numel();
if (input_data == nullptr) {
framework::DDim out_dim{0};
output->Resize(out_dim);
output->mutable_data<T>(context.GetPlace());
return;
}
auto input_x = framework::EigenVector<InputT>::Flatten(*input);
framework::Tensor input_min_t, input_max_t;
auto* input_max_data =
input_max_t.mutable_data<InputT>({1}, context.GetPlace());
auto* input_min_data =
input_min_t.mutable_data<InputT>({1}, context.GetPlace());
auto input_max_scala = framework::EigenScalar<InputT>::From(input_max_t);
auto input_min_scala = framework::EigenScalar<InputT>::From(input_min_t);
auto* place = context.template device_context<DeviceContext>().eigen_device();
input_max_scala.device(*place) = input_x.maximum();
input_min_scala.device(*place) = input_x.minimum();
Tensor input_min_cpu, input_max_cpu;
TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);
TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);
InputT input_min = input_min_cpu.data<InputT>()[0];
PADDLE_ENFORCE_GE(
input_min, static_cast<InputT>(0),
platform::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));
int64_t output_size =
static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;
output_size = std::max(output_size, static_cast<int64_t>(minlength));
framework::DDim out_dim{output_size};
output->Resize(out_dim);
bool has_weights = (weights != nullptr);
const T* weights_data = has_weights ? weights->data<T>() : nullptr;
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
if (!has_weights) {
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<DeviceContext, int64_t>()(
context.template device_context<DeviceContext>(), output, 0L);
KernelBincount<T, InputT, int64_t><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
} else {
const auto& weights_type = weights->type();
if (weights_type == framework::proto::VarType::FP32) {
float* output_data = output->mutable_data<float>(context.GetPlace());
math::SetConstant<DeviceContext, float>()(
context.template device_context<DeviceContext>(), output,
static_cast<float>(0));
KernelBincount<T, InputT, float><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
} else {
double* output_data = output->mutable_data<double>(context.GetPlace());
math::SetConstant<DeviceContext, double>()(
context.template device_context<DeviceContext>(), output,
static_cast<double>(0));
KernelBincount<T, InputT, double><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
}
}
}
template <typename DeviceContext, typename T>
class BincountCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<framework::Tensor>("X");
const auto& input_type = input->type();
if (input_type == framework::proto::VarType::INT32) {
BincountCUDAInner<DeviceContext, T, int>(context);
} else if (input_type == framework::proto::VarType::INT64) {
BincountCUDAInner<DeviceContext, T, int64_t>(context);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
bincount, ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T, typename InputT>
void BincountInner(const framework::ExecutionContext& context) {
const Tensor* input = context.Input<framework::Tensor>("X");
const Tensor* weights = context.Input<framework::Tensor>("Weights");
Tensor* output = context.Output<framework::Tensor>("Out");
auto& minlength = context.Attr<int>("minlength");
const InputT* input_data = input->data<InputT>();
auto input_numel = input->numel();
if (input_data == nullptr) {
framework::DDim out_dim{0};
output->Resize(out_dim);
output->mutable_data<InputT>(context.GetPlace());
return;
}
PADDLE_ENFORCE_GE(
*std::min_element(input_data, input_data + input_numel),
static_cast<InputT>(0),
platform::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));
int64_t output_size = static_cast<int64_t>(*std::max_element(
input_data, input_data + input_numel)) +
1L;
output_size = std::max(output_size, static_cast<int64_t>(minlength));
framework::DDim out_dim{output_size};
output->Resize(out_dim);
bool has_weights = (weights != nullptr);
if (has_weights) {
const T* weights_data = weights->data<T>();
const auto& weights_type = weights->type();
if (weights_type == framework::proto::VarType::FP32) {
float* output_data = output->mutable_data<float>(context.GetPlace());
math::SetConstant<DeviceContext, float>()(
context.template device_context<DeviceContext>(), output,
static_cast<float>(0));
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += static_cast<float>(weights_data[i]);
}
} else {
double* output_data = output->mutable_data<double>(context.GetPlace());
math::SetConstant<DeviceContext, double>()(
context.template device_context<DeviceContext>(), output,
static_cast<double>(0));
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += static_cast<double>(weights_data[i]);
}
}
} else {
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<DeviceContext, int64_t>()(
context.template device_context<DeviceContext>(), output, 0L);
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += 1L;
}
}
}
template <typename DeviceContext, typename T>
class BincountKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<framework::Tensor>("X");
const auto& input_type = input->type();
if (input_type == framework::proto::VarType::INT32) {
BincountInner<DeviceContext, T, int>(context);
} else if (input_type == framework::proto::VarType::INT64) {
BincountInner<DeviceContext, T, int64_t>(context);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -40,6 +40,7 @@
// need to manually specify them in this map.
std::map<std::string, std::set<std::string>> op_ins_map = {
{"layer_norm", {"X", "Scale", "Bias"}},
{"bincount", {"X", "Weights"}},
{"fused_attention",
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW",
"OutLinearBias", "Ln2Scale", "Ln2Bias"}},
......
......@@ -99,6 +99,7 @@ from .tensor.linalg import cross # noqa: F401
from .tensor.linalg import cholesky # noqa: F401
from .tensor.linalg import bmm # noqa: F401
from .tensor.linalg import histogram # noqa: F401
from .tensor.linalg import bincount # noqa: F401
from .tensor.linalg import mv # noqa: F401
from .tensor.logic import equal # noqa: F401
from .tensor.logic import greater_equal # noqa: F401
......@@ -398,6 +399,7 @@ __all__ = [ # noqa
'bitwise_not',
'mm',
'flip',
'bincount',
'histogram',
'multiplex',
'CUDAPlace',
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest
paddle.enable_static()
class TestBincountOpAPI(unittest.TestCase):
"""Test bincount api."""
def test_static_graph(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
inputs = fluid.data(name='input', dtype='int64', shape=[7])
weights = fluid.data(name='weights', dtype='int64', shape=[7])
output = paddle.bincount(inputs, weights=weights)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
img = np.array([0, 1, 1, 3, 2, 1, 7]).astype(np.int64)
w = np.array([0, 1, 1, 2, 2, 1, 0]).astype(np.int64)
res = exe.run(train_program,
feed={'input': img,
'weights': w},
fetch_list=[output])
actual = np.array(res[0])
expected = np.bincount(img, weights=w)
self.assertTrue(
(actual == expected).all(),
msg='bincount output is wrong, out =' + str(actual))
def test_dygraph(self):
with fluid.dygraph.guard():
inputs_np = np.array([0, 1, 1, 3, 2, 1, 7]).astype(np.int64)
inputs = fluid.dygraph.to_variable(inputs_np)
actual = paddle.bincount(inputs)
expected = np.bincount(inputs)
self.assertTrue(
(actual.numpy() == expected).all(),
msg='bincount output is wrong, out =' + str(actual.numpy()))
class TestBincountOpError(unittest.TestCase):
"""Test bincount op error."""
def run_network(self, net_func):
with fluid.dygraph.guard():
net_func()
def test_input_value_error(self):
"""Test input tensor should be non-negative."""
def net_func():
input_value = paddle.to_tensor([1, 2, 3, 4, -5])
paddle.bincount(input_value)
with self.assertRaises(ValueError):
self.run_network(net_func)
def test_input_shape_error(self):
"""Test input tensor should be 1-D tansor."""
def net_func():
input_value = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
paddle.bincount(input_value)
with self.assertRaises(ValueError):
self.run_network(net_func)
def test_minlength_value_error(self):
"""Test minlength is non-negative ints."""
def net_func():
input_value = paddle.to_tensor([1, 2, 3, 4, 5])
paddle.bincount(input_value, minlength=-1)
with self.assertRaises(IndexError):
self.run_network(net_func)
def test_input_type_errors(self):
"""Test input tensor should only contain non-negative ints."""
def net_func():
input_value = paddle.to_tensor([1., 2., 3., 4., 5.])
paddle.bincount(input_value)
with self.assertRaises(TypeError):
self.run_network(net_func)
def test_weights_shape_error(self):
"""Test weights tensor should have the same shape as input tensor."""
def net_func():
input_value = paddle.to_tensor([1, 2, 3, 4, 5])
weights = paddle.to_tensor([1, 1, 1, 1, 1, 1])
paddle.bincount(input_value, weights=weights)
with self.assertRaises(ValueError):
self.run_network(net_func)
class TestBincountOp(OpTest):
# without weights
def setUp(self):
self.op_type = "bincount"
self.init_test_case()
self.inputs = {"X": self.np_input}
self.attrs = {"minlength": self.minlength}
self.outputs = {"Out": self.Out}
def init_test_case(self):
self.minlength = 0
self.np_input = np.random.randint(low=0, high=20, size=10)
self.Out = np.bincount(self.np_input, minlength=self.minlength)
def test_check_output(self):
self.check_output()
class TestCase1(TestBincountOp):
# with weights(FLOAT32)
def setUp(self):
self.op_type = "bincount"
self.init_test_case()
self.inputs = {"X": self.np_input, "Weights": self.np_weights}
self.attrs = {"minlength": self.minlength}
self.outputs = {"Out": self.Out}
def init_test_case(self):
self.minlength = 0
self.np_weights = np.random.randint(
low=0, high=20, size=10).astype(np.float32)
self.np_input = np.random.randint(low=0, high=20, size=10)
self.Out = np.bincount(
self.np_input, weights=self.np_weights,
minlength=self.minlength).astype(np.float32)
class TestCase2(TestBincountOp):
# with weights(other)
def setUp(self):
self.op_type = "bincount"
self.init_test_case()
self.inputs = {"X": self.np_input, "Weights": self.np_weights}
self.attrs = {"minlength": self.minlength}
self.outputs = {"Out": self.Out}
def init_test_case(self):
self.minlength = 0
self.np_weights = np.random.randint(low=0, high=20, size=10)
self.np_input = np.random.randint(low=0, high=20, size=10)
self.Out = np.bincount(
self.np_input, weights=self.np_weights, minlength=self.minlength)
class TestCase3(TestBincountOp):
# empty input
def init_test_case(self):
self.minlength = 0
self.np_input = np.array([], dtype=np.int64)
self.Out = np.bincount(self.np_input, minlength=self.minlength)
class TestCase4(TestBincountOp):
# with input(INT32)
def init_test_case(self):
self.minlength = 0
self.np_input = np.random.randint(
low=0, high=20, size=10).astype(np.int32)
self.Out = np.bincount(self.np_input, minlength=self.minlength)
class TestCase5(TestBincountOp):
# with minlength greater than max(X)
def init_test_case(self):
self.minlength = 20
self.np_input = np.random.randint(low=0, high=10, size=10)
self.Out = np.bincount(self.np_input, minlength=self.minlength)
if __name__ == "__main__":
unittest.main()
......@@ -44,6 +44,7 @@ from .linalg import cross # noqa: F401
from .linalg import cholesky # noqa: F401
from .linalg import bmm # noqa: F401
from .linalg import histogram # noqa: F401
from .linalg import bincount # noqa: F401
from .linalg import mv # noqa: F401
from .linalg import eig # noqa: F401
from .linalg import matrix_power # noqa: F401
......@@ -236,6 +237,7 @@ tensor_method_func = [ #noqa
'cholesky',
'bmm',
'histogram',
'bincount',
'mv',
'matrix_power',
'qr',
......
......@@ -1293,6 +1293,59 @@ def histogram(input, bins=100, min=0, max=0, name=None):
return out
def bincount(x, weights=None, minlength=0, name=None):
"""
Computes frequency of each value in the input tensor.
Args:
x (Tensor): A Tensor with non-negative integer. Should be 1-D tensor.
weights (Tensor, optional): Weight for each value in the input tensor. Should have the same shape as input. Default is None.
minlength (int, optional): Minimum number of bins. Should be non-negative integer. Default is 0.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The tensor of frequency.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1, 2, 1, 4, 5])
result1 = paddle.bincount(x)
print(result1) # [0, 2, 1, 0, 1, 1]
w = paddle.to_tensor([2.1, 0.4, 0.1, 0.5, 0.5])
result2 = paddle.bincount(x, weights=w)
print(result2) # [0., 2.19999981, 0.40000001, 0., 0.50000000, 0.50000000]
"""
if x.dtype not in [paddle.int32, paddle.int64]:
raise TypeError("Elements in Input(x) should all be integers")
if in_dygraph_mode():
return _C_ops.bincount(x, weights, "minlength", minlength)
helper = LayerHelper('bincount', **locals())
check_variable_and_dtype(x, 'X', ['int32', 'int64'], 'bincount')
if weights is not None:
check_variable_and_dtype(weights, 'Weights',
['int32', 'int64', 'float32', 'float64'],
'bincount')
out = helper.create_variable_for_type_inference(dtype=weights.dtype)
else:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='bincount',
inputs={'X': x,
'Weights': weights},
outputs={'Out': out},
attrs={'minlength': minlength})
return out
def mv(x, vec, name=None):
"""
Performs a matrix-vector product of the matrix x and the vector vec.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册