From 991ec7d3097e07e6c8cfb92be9c480a9583c5657 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 23 Sep 2022 20:25:04 +0800 Subject: [PATCH] [Phi] support bincount yaml and _C_ops.bincount under eager (#46443) --- paddle/phi/api/yaml/legacy_ops.yaml | 9 +++++++++ .../paddle/fluid/tests/unittests/test_bincount_op.py | 12 ++++++++++-- python/paddle/tensor/linalg.py | 4 +++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index d6e9218f30..7f25ba6be4 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -3018,3 +3018,12 @@ func: unpool3d data_type: x backward: unpool3d_grad + +- op: bincount + args: (Tensor x, Tensor weights, Scalar minlength) + output: Tensor(out) + infer_meta: + func: BincountInferMeta + kernel: + func: bincount + optional: weights diff --git a/python/paddle/fluid/tests/unittests/test_bincount_op.py b/python/paddle/fluid/tests/unittests/test_bincount_op.py index ca0113fe7f..c85c9d4b6a 100644 --- a/python/paddle/fluid/tests/unittests/test_bincount_op.py +++ b/python/paddle/fluid/tests/unittests/test_bincount_op.py @@ -24,6 +24,7 @@ import paddle.fluid.core as core from paddle.fluid import Program, program_guard from op_test import OpTest import paddle.inference as paddle_infer +from paddle.fluid.framework import in_dygraph_mode paddle.enable_static() @@ -101,8 +102,15 @@ class TestBincountOpError(unittest.TestCase): input_value = paddle.to_tensor([1, 2, 3, 4, 5]) paddle.bincount(input_value, minlength=-1) - with self.assertRaises(IndexError): - self.run_network(net_func) + with fluid.dygraph.guard(): + if in_dygraph_mode(): + # InvalidArgument for phi BincountKernel + with self.assertRaises(ValueError): + self.run_network(net_func) + else: + # OutOfRange for EqualGreaterThanChecker + with self.assertRaises(IndexError): + self.run_network(net_func) def test_input_type_errors(self): """Test input tensor should only contain non-negative ints.""" diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index b0ef5820a1..7c4644bc40 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1651,7 +1651,9 @@ def bincount(x, weights=None, minlength=0, name=None): if x.dtype not in [paddle.int32, paddle.int64]: raise TypeError("Elements in Input(x) should all be integers") - if _non_static_mode(): + if in_dygraph_mode(): + return _C_ops.bincount(x, weights, minlength) + elif _in_legacy_dygraph(): return _legacy_C_ops.bincount(x, weights, "minlength", minlength) helper = LayerHelper('bincount', **locals()) -- GitLab