diff --git a/paddle/fluid/operators/bincount_op.cc b/paddle/fluid/operators/bincount_op.cc index 140f98916ea18d427502f9a47be1481eb5e51d35..d52de7ace64abcb4eb88ca275ba03ccd10d0620a 100644 --- a/paddle/fluid/operators/bincount_op.cc +++ b/paddle/fluid/operators/bincount_op.cc @@ -50,7 +50,8 @@ class BincountOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor) The output tensor of Bincount op,"); AddAttr("minlength", "(int) The minimal numbers of bins") .SetDefault(0) - .EqualGreaterThan(0); + .EqualGreaterThan(0) + .SupportTensor(); AddComment(R"DOC( Bincount Operator. Computes frequency of each value in the input tensor. diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 7631457dc4a4a45b7b6823c4f3c848e61a053f27..ad8897bb4c07848cf285199cb20e9b02395b7b09 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -210,17 +210,10 @@ void BCELossInferMeta(const MetaTensor& input, void BincountInferMeta(const MetaTensor& x, const MetaTensor& weights, - int minlength, + const Scalar& minlength, MetaTensor* out) { auto input_dim = x.dims(); - PADDLE_ENFORCE_GE(minlength, - 0, - phi::errors::InvalidArgument( - "The minlength should be greater than or equal to 0." - "But received minlength is %d", - minlength)); - PADDLE_ENFORCE_EQ( input_dim.size(), 1, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index d28d15f0829e79dcbde04ff3ccbba9fbdf5aadf5..10430c289e41d2a7d6d2b1d50e5f24fcf38a830a 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -57,7 +57,7 @@ void BCELossInferMeta(const MetaTensor& input, void BincountInferMeta(const MetaTensor& x, const MetaTensor& weights, - int minlength, + const Scalar& minlength, MetaTensor* out); void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); diff --git a/paddle/phi/kernels/bincount_kernel.h b/paddle/phi/kernels/bincount_kernel.h index e110b6e014b4de11babb70cbfb8fb05bf165a6c7..7b72a1fabd8f52e7a1a88b1027343d42535fcfd5 100644 --- a/paddle/phi/kernels/bincount_kernel.h +++ b/paddle/phi/kernels/bincount_kernel.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -22,7 +23,7 @@ template void BincountKernel(const Context& dev_ctx, const DenseTensor& x, const paddle::optional& weights, - int minlength, + const Scalar& minlength, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/bincount_kernel.cc b/paddle/phi/kernels/cpu/bincount_kernel.cc index 8163953c1e00e129cb4443a56c1d5906b2d246fd..97f28f541df09e5c4b5ff8b6f0cc76c45ee9d646 100644 --- a/paddle/phi/kernels/cpu/bincount_kernel.cc +++ b/paddle/phi/kernels/cpu/bincount_kernel.cc @@ -86,12 +86,20 @@ template void BincountKernel(const Context& dev_ctx, const DenseTensor& x, const paddle::optional& weights, - int minlength, + const Scalar& minlength, DenseTensor* out) { + int int_minlength = minlength.to(); + PADDLE_ENFORCE_GE(int_minlength, + 0, + phi::errors::InvalidArgument( + "The minlength should be greater than or equal to 0." + "But received minlength is %d", + int_minlength)); + if (x.dtype() == DataType::INT32) { - BincountInner(dev_ctx, x, weights, minlength, out); + BincountInner(dev_ctx, x, weights, int_minlength, out); } else if (x.dtype() == DataType::INT64) { - BincountInner(dev_ctx, x, weights, minlength, out); + BincountInner(dev_ctx, x, weights, int_minlength, out); } } } // namespace phi diff --git a/paddle/phi/kernels/gpu/bincount_kernel.cu b/paddle/phi/kernels/gpu/bincount_kernel.cu index d6073193a15056dc22c72dff093852c186fcbd06..3b1e41d92e6b676d0f092236946ea7e1314a1453 100644 --- a/paddle/phi/kernels/gpu/bincount_kernel.cu +++ b/paddle/phi/kernels/gpu/bincount_kernel.cu @@ -138,12 +138,21 @@ template void BincountKernel(const Context& dev_ctx, const DenseTensor& x, const paddle::optional& weights, - int minlength, + const Scalar& minlength, DenseTensor* out) { + int int_minlength = minlength.to(); + PADDLE_ENFORCE_GE(int_minlength, + 0, + phi::errors::InvalidArgument( + "The minlength should be greater than or equal to 0." + "But received minlength is %d", + int_minlength)); + if (x.dtype() == DataType::INT32) { - BincountCUDAInner(dev_ctx, x, weights, minlength, out); + BincountCUDAInner(dev_ctx, x, weights, int_minlength, out); } else if (x.dtype() == DataType::INT64) { - BincountCUDAInner(dev_ctx, x, weights, minlength, out); + BincountCUDAInner( + dev_ctx, x, weights, int_minlength, out); } } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_bincount_op.py b/python/paddle/fluid/tests/unittests/test_bincount_op.py index 2b99c92191150431a053a458cffac7bbd6aa19c0..ca0113fe7fcc0a67334c5596b6561581a3121a4c 100644 --- a/python/paddle/fluid/tests/unittests/test_bincount_op.py +++ b/python/paddle/fluid/tests/unittests/test_bincount_op.py @@ -14,13 +14,16 @@ from __future__ import print_function +import os import unittest +import tempfile import numpy as np import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid import Program, program_guard from op_test import OpTest +import paddle.inference as paddle_infer paddle.enable_static() @@ -206,5 +209,66 @@ class TestCase5(TestBincountOp): self.Out = np.bincount(self.np_input, minlength=self.minlength) +class TestTensorMinlength(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + paddle.seed(2022) + self.temp_dir = tempfile.TemporaryDirectory() + self.save_path = os.path.join(self.temp_dir.name, + 'tensor_minlength_bincount') + self.place = paddle.CUDAPlace( + 0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() + + def test_dygraph(self): + paddle.disable_static() + x = np.random.randint(0, 10, [20]) + minlength = 2 + np_out = np.bincount(x, minlength=minlength) + pd_out = paddle.bincount(paddle.to_tensor(x), + minlength=paddle.to_tensor([2], dtype='int32')) + np.testing.assert_allclose(np_out, pd_out.numpy()) + + def test_static_and_infer(self): + paddle.enable_static() + np_x = np.random.randn(100).astype('float32') + main_prog = paddle.static.Program() + starup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, starup_prog): + # run static + x = paddle.static.data(shape=np_x.shape, name='x', dtype=np_x.dtype) + linear = paddle.nn.Linear(np_x.shape[0], np_x.shape[0]) + linear_out = linear(x) + relu_out = paddle.nn.functional.relu(linear_out) + minlength = paddle.full([1], 3, dtype='int32') + out = paddle.bincount(paddle.cast(relu_out, 'int32'), + minlength=minlength) + + exe = paddle.static.Executor(self.place) + exe.run(starup_prog) + static_out = exe.run(feed={'x': np_x}, fetch_list=[out]) + + # run infer + paddle.static.save_inference_model(self.save_path, [x], [out], exe) + config = paddle_infer.Config(self.save_path + '.pdmodel', + self.save_path + '.pdiparams') + if paddle.is_compiled_with_cuda(): + config.enable_use_gpu(100, 0) + else: + config.disable_gpu() + + predictor = paddle_infer.create_predictor(config) + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + fake_input = np_x + input_handle.reshape(np_x.shape) + input_handle.copy_from_cpu(fake_input) + predictor.run() + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + infer_out = output_handle.copy_to_cpu() + np.testing.assert_allclose(static_out[0], infer_out) + + if __name__ == "__main__": unittest.main()