未验证 提交 991ec7d3 编写于 作者: W Weilong Wu 提交者: GitHub

[Phi] support bincount yaml and _C_ops.bincount under eager (#46443)

上级 c4de1277
...@@ -3018,3 +3018,12 @@ ...@@ -3018,3 +3018,12 @@
func: unpool3d func: unpool3d
data_type: x data_type: x
backward: unpool3d_grad backward: unpool3d_grad
- op: bincount
args: (Tensor x, Tensor weights, Scalar minlength)
output: Tensor(out)
infer_meta:
func: BincountInferMeta
kernel:
func: bincount
optional: weights
...@@ -24,6 +24,7 @@ import paddle.fluid.core as core ...@@ -24,6 +24,7 @@ import paddle.fluid.core as core
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
from op_test import OpTest from op_test import OpTest
import paddle.inference as paddle_infer import paddle.inference as paddle_infer
from paddle.fluid.framework import in_dygraph_mode
paddle.enable_static() paddle.enable_static()
...@@ -101,8 +102,15 @@ class TestBincountOpError(unittest.TestCase): ...@@ -101,8 +102,15 @@ class TestBincountOpError(unittest.TestCase):
input_value = paddle.to_tensor([1, 2, 3, 4, 5]) input_value = paddle.to_tensor([1, 2, 3, 4, 5])
paddle.bincount(input_value, minlength=-1) paddle.bincount(input_value, minlength=-1)
with self.assertRaises(IndexError): with fluid.dygraph.guard():
self.run_network(net_func) if in_dygraph_mode():
# InvalidArgument for phi BincountKernel
with self.assertRaises(ValueError):
self.run_network(net_func)
else:
# OutOfRange for EqualGreaterThanChecker
with self.assertRaises(IndexError):
self.run_network(net_func)
def test_input_type_errors(self): def test_input_type_errors(self):
"""Test input tensor should only contain non-negative ints.""" """Test input tensor should only contain non-negative ints."""
......
...@@ -1651,7 +1651,9 @@ def bincount(x, weights=None, minlength=0, name=None): ...@@ -1651,7 +1651,9 @@ def bincount(x, weights=None, minlength=0, name=None):
if x.dtype not in [paddle.int32, paddle.int64]: if x.dtype not in [paddle.int32, paddle.int64]:
raise TypeError("Elements in Input(x) should all be integers") raise TypeError("Elements in Input(x) should all be integers")
if _non_static_mode(): if in_dygraph_mode():
return _C_ops.bincount(x, weights, minlength)
elif _in_legacy_dygraph():
return _legacy_C_ops.bincount(x, weights, "minlength", minlength) return _legacy_C_ops.bincount(x, weights, "minlength", minlength)
helper = LayerHelper('bincount', **locals()) helper = LayerHelper('bincount', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册