未验证 提交 12917c8c 编写于 作者: W WangZhen 提交者: GitHub

[OpAttr]Adapt tensor minlength for bincount (#45342)

* Adapt minlength attr for bincount
上级 8da6b72b
......@@ -50,7 +50,8 @@ class BincountOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "(Tensor) The output tensor of Bincount op,");
AddAttr<int>("minlength", "(int) The minimal numbers of bins")
.SetDefault(0)
.EqualGreaterThan(0);
.EqualGreaterThan(0)
.SupportTensor();
AddComment(R"DOC(
Bincount Operator.
Computes frequency of each value in the input tensor.
......
......@@ -210,17 +210,10 @@ void BCELossInferMeta(const MetaTensor& input,
void BincountInferMeta(const MetaTensor& x,
const MetaTensor& weights,
int minlength,
const Scalar& minlength,
MetaTensor* out) {
auto input_dim = x.dims();
PADDLE_ENFORCE_GE(minlength,
0,
phi::errors::InvalidArgument(
"The minlength should be greater than or equal to 0."
"But received minlength is %d",
minlength));
PADDLE_ENFORCE_EQ(
input_dim.size(),
1,
......
......@@ -57,7 +57,7 @@ void BCELossInferMeta(const MetaTensor& input,
void BincountInferMeta(const MetaTensor& x,
const MetaTensor& weights,
int minlength,
const Scalar& minlength,
MetaTensor* out);
void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
......@@ -22,7 +23,7 @@ template <typename T, typename Context>
void BincountKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& weights,
int minlength,
const Scalar& minlength,
DenseTensor* out);
} // namespace phi
......@@ -86,12 +86,20 @@ template <typename T, typename Context>
void BincountKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& weights,
int minlength,
const Scalar& minlength,
DenseTensor* out) {
int int_minlength = minlength.to<int>();
PADDLE_ENFORCE_GE(int_minlength,
0,
phi::errors::InvalidArgument(
"The minlength should be greater than or equal to 0."
"But received minlength is %d",
int_minlength));
if (x.dtype() == DataType::INT32) {
BincountInner<Context, T, int>(dev_ctx, x, weights, minlength, out);
BincountInner<Context, T, int>(dev_ctx, x, weights, int_minlength, out);
} else if (x.dtype() == DataType::INT64) {
BincountInner<Context, T, int64_t>(dev_ctx, x, weights, minlength, out);
BincountInner<Context, T, int64_t>(dev_ctx, x, weights, int_minlength, out);
}
}
} // namespace phi
......
......@@ -138,12 +138,21 @@ template <typename T, typename Context>
void BincountKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& weights,
int minlength,
const Scalar& minlength,
DenseTensor* out) {
int int_minlength = minlength.to<int>();
PADDLE_ENFORCE_GE(int_minlength,
0,
phi::errors::InvalidArgument(
"The minlength should be greater than or equal to 0."
"But received minlength is %d",
int_minlength));
if (x.dtype() == DataType::INT32) {
BincountCUDAInner<Context, T, int>(dev_ctx, x, weights, minlength, out);
BincountCUDAInner<Context, T, int>(dev_ctx, x, weights, int_minlength, out);
} else if (x.dtype() == DataType::INT64) {
BincountCUDAInner<Context, T, int64_t>(dev_ctx, x, weights, minlength, out);
BincountCUDAInner<Context, T, int64_t>(
dev_ctx, x, weights, int_minlength, out);
}
}
} // namespace phi
......
......@@ -14,13 +14,16 @@
from __future__ import print_function
import os
import unittest
import tempfile
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest
import paddle.inference as paddle_infer
paddle.enable_static()
......@@ -206,5 +209,66 @@ class TestCase5(TestBincountOp):
self.Out = np.bincount(self.np_input, minlength=self.minlength)
class TestTensorMinlength(unittest.TestCase):
def setUp(self):
paddle.disable_static()
paddle.seed(2022)
self.temp_dir = tempfile.TemporaryDirectory()
self.save_path = os.path.join(self.temp_dir.name,
'tensor_minlength_bincount')
self.place = paddle.CUDAPlace(
0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
def test_dygraph(self):
paddle.disable_static()
x = np.random.randint(0, 10, [20])
minlength = 2
np_out = np.bincount(x, minlength=minlength)
pd_out = paddle.bincount(paddle.to_tensor(x),
minlength=paddle.to_tensor([2], dtype='int32'))
np.testing.assert_allclose(np_out, pd_out.numpy())
def test_static_and_infer(self):
paddle.enable_static()
np_x = np.random.randn(100).astype('float32')
main_prog = paddle.static.Program()
starup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, starup_prog):
# run static
x = paddle.static.data(shape=np_x.shape, name='x', dtype=np_x.dtype)
linear = paddle.nn.Linear(np_x.shape[0], np_x.shape[0])
linear_out = linear(x)
relu_out = paddle.nn.functional.relu(linear_out)
minlength = paddle.full([1], 3, dtype='int32')
out = paddle.bincount(paddle.cast(relu_out, 'int32'),
minlength=minlength)
exe = paddle.static.Executor(self.place)
exe.run(starup_prog)
static_out = exe.run(feed={'x': np_x}, fetch_list=[out])
# run infer
paddle.static.save_inference_model(self.save_path, [x], [out], exe)
config = paddle_infer.Config(self.save_path + '.pdmodel',
self.save_path + '.pdiparams')
if paddle.is_compiled_with_cuda():
config.enable_use_gpu(100, 0)
else:
config.disable_gpu()
predictor = paddle_infer.create_predictor(config)
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
fake_input = np_x
input_handle.reshape(np_x.shape)
input_handle.copy_from_cpu(fake_input)
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
infer_out = output_handle.copy_to_cpu()
np.testing.assert_allclose(static_out[0], infer_out)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册