Unverified commit 12917c8c, authored by WangZhen, committed by GitHub

[OpAttr]Adapt tensor minlength for bincount (#45342)

* Adapt minlength attr for bincount
Parent 8da6b72b
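After this change, `minlength` accepts a Tensor as well as a plain Python int. A minimal dygraph sketch of the new usage, mirroring the unit test added at the end of this diff:

import numpy as np
import paddle

paddle.disable_static()
x = np.random.randint(0, 10, [20])
# np.bincount still takes a plain int; paddle.bincount can now also take a Tensor
np_out = np.bincount(x, minlength=2)
pd_out = paddle.bincount(paddle.to_tensor(x),
                         minlength=paddle.to_tensor([2], dtype='int32'))
np.testing.assert_allclose(np_out, pd_out.numpy())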
@@ -50,7 +50,8 @@ class BincountOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "(Tensor) The output tensor of Bincount op,");
     AddAttr<int>("minlength", "(int) The minimal numbers of bins")
         .SetDefault(0)
-        .EqualGreaterThan(0);
+        .EqualGreaterThan(0)
+        .SupportTensor();
     AddComment(R"DOC(
           Bincount Operator.
           Computes frequency of each value in the input tensor.
......
@@ -210,17 +210,10 @@ void BCELossInferMeta(const MetaTensor& input,
 void BincountInferMeta(const MetaTensor& x,
                        const MetaTensor& weights,
-                       int minlength,
+                       const Scalar& minlength,
                        MetaTensor* out) {
   auto input_dim = x.dims();

-  PADDLE_ENFORCE_GE(minlength,
-                    0,
-                    phi::errors::InvalidArgument(
-                        "The minlength should be greater than or equal to 0."
-                        "But received minlength is %d",
-                        minlength));
-
   PADDLE_ENFORCE_EQ(
       input_dim.size(),
       1,
......
@@ -57,7 +57,7 @@ void BCELossInferMeta(const MetaTensor& input,
 void BincountInferMeta(const MetaTensor& x,
                        const MetaTensor& weights,
-                       int minlength,
+                       const Scalar& minlength,
                        MetaTensor* out);

 void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
......
@@ -14,6 +14,7 @@

 #pragma once

+#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"

 namespace phi {
@@ -22,7 +23,7 @@ template <typename T, typename Context>
 void BincountKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const paddle::optional<DenseTensor>& weights,
-                    int minlength,
+                    const Scalar& minlength,
                     DenseTensor* out);

 }  // namespace phi
@@ -86,12 +86,20 @@ template <typename T, typename Context>
 void BincountKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const paddle::optional<DenseTensor>& weights,
-                    int minlength,
+                    const Scalar& minlength,
                     DenseTensor* out) {
+  int int_minlength = minlength.to<int>();
+  PADDLE_ENFORCE_GE(int_minlength,
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The minlength should be greater than or equal to 0."
+                        "But received minlength is %d",
+                        int_minlength));
   if (x.dtype() == DataType::INT32) {
-    BincountInner<Context, T, int>(dev_ctx, x, weights, minlength, out);
+    BincountInner<Context, T, int>(dev_ctx, x, weights, int_minlength, out);
   } else if (x.dtype() == DataType::INT64) {
-    BincountInner<Context, T, int64_t>(dev_ctx, x, weights, minlength, out);
+    BincountInner<Context, T, int64_t>(dev_ctx, x, weights, int_minlength, out);
   }
 }

 }  // namespace phi
......
@@ -138,12 +138,21 @@ template <typename T, typename Context>
 void BincountKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const paddle::optional<DenseTensor>& weights,
-                    int minlength,
+                    const Scalar& minlength,
                     DenseTensor* out) {
+  int int_minlength = minlength.to<int>();
+  PADDLE_ENFORCE_GE(int_minlength,
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The minlength should be greater than or equal to 0."
+                        "But received minlength is %d",
+                        int_minlength));
   if (x.dtype() == DataType::INT32) {
-    BincountCUDAInner<Context, T, int>(dev_ctx, x, weights, minlength, out);
+    BincountCUDAInner<Context, T, int>(dev_ctx, x, weights, int_minlength, out);
   } else if (x.dtype() == DataType::INT64) {
-    BincountCUDAInner<Context, T, int64_t>(dev_ctx, x, weights, minlength, out);
+    BincountCUDAInner<Context, T, int64_t>(
+        dev_ctx, x, weights, int_minlength, out);
   }
 }

 }  // namespace phi
......
@@ -14,13 +14,16 @@
 from __future__ import print_function

+import os
 import unittest
+import tempfile
 import numpy as np
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid import Program, program_guard
 from op_test import OpTest
+import paddle.inference as paddle_infer

 paddle.enable_static()
@@ -206,5 +209,66 @@ class TestCase5(TestBincountOp):
         self.Out = np.bincount(self.np_input, minlength=self.minlength)

+
+class TestTensorMinlength(unittest.TestCase):
+
+    def setUp(self):
+        paddle.disable_static()
+        paddle.seed(2022)
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.save_path = os.path.join(self.temp_dir.name,
+                                      'tensor_minlength_bincount')
+        self.place = paddle.CUDAPlace(
+            0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
+
+    def test_dygraph(self):
+        paddle.disable_static()
+        x = np.random.randint(0, 10, [20])
+        minlength = 2
+        np_out = np.bincount(x, minlength=minlength)
+        pd_out = paddle.bincount(paddle.to_tensor(x),
+                                 minlength=paddle.to_tensor([2], dtype='int32'))
+        np.testing.assert_allclose(np_out, pd_out.numpy())
+
+    def test_static_and_infer(self):
+        paddle.enable_static()
+        np_x = np.random.randn(100).astype('float32')
+        main_prog = paddle.static.Program()
+        starup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, starup_prog):
+            # run static
+            x = paddle.static.data(shape=np_x.shape, name='x', dtype=np_x.dtype)
+            linear = paddle.nn.Linear(np_x.shape[0], np_x.shape[0])
+            linear_out = linear(x)
+            relu_out = paddle.nn.functional.relu(linear_out)
+            minlength = paddle.full([1], 3, dtype='int32')
+            out = paddle.bincount(paddle.cast(relu_out, 'int32'),
+                                  minlength=minlength)
+
+            exe = paddle.static.Executor(self.place)
+            exe.run(starup_prog)
+            static_out = exe.run(feed={'x': np_x}, fetch_list=[out])
+
+            # run infer
+            paddle.static.save_inference_model(self.save_path, [x], [out], exe)
+            config = paddle_infer.Config(self.save_path + '.pdmodel',
+                                         self.save_path + '.pdiparams')
+            if paddle.is_compiled_with_cuda():
+                config.enable_use_gpu(100, 0)
+            else:
+                config.disable_gpu()
+            predictor = paddle_infer.create_predictor(config)
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+            fake_input = np_x
+            input_handle.reshape(np_x.shape)
+            input_handle.copy_from_cpu(fake_input)
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+            infer_out = output_handle.copy_to_cpu()
+            np.testing.assert_allclose(static_out[0], infer_out)
+
+
 if __name__ == "__main__":
     unittest.main()