未验证 提交 991ec7d3 编写于 作者: W Weilong Wu 提交者: GitHub

[Phi] support bincount yaml and _C_ops.bincount under eager (#46443)

上级 c4de1277
......@@ -3018,3 +3018,12 @@
func: unpool3d
data_type: x
backward: unpool3d_grad
- op: bincount
args: (Tensor x, Tensor weights, Scalar minlength)
output: Tensor(out)
infer_meta:
func: BincountInferMeta
kernel:
func: bincount
optional: weights
......@@ -24,6 +24,7 @@ import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest
import paddle.inference as paddle_infer
from paddle.fluid.framework import in_dygraph_mode
paddle.enable_static()
......@@ -101,8 +102,15 @@ class TestBincountOpError(unittest.TestCase):
input_value = paddle.to_tensor([1, 2, 3, 4, 5])
paddle.bincount(input_value, minlength=-1)
with self.assertRaises(IndexError):
self.run_network(net_func)
with fluid.dygraph.guard():
if in_dygraph_mode():
# InvalidArgument for phi BincountKernel
with self.assertRaises(ValueError):
self.run_network(net_func)
else:
# OutOfRange for EqualGreaterThanChecker
with self.assertRaises(IndexError):
self.run_network(net_func)
def test_input_type_errors(self):
"""Test input tensor should only contain non-negative ints."""
......
......@@ -1651,7 +1651,9 @@ def bincount(x, weights=None, minlength=0, name=None):
if x.dtype not in [paddle.int32, paddle.int64]:
raise TypeError("Elements in Input(x) should all be integers")
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.bincount(x, weights, minlength)
elif _in_legacy_dygraph():
return _legacy_C_ops.bincount(x, weights, "minlength", minlength)
helper = LayerHelper('bincount', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册