提交 5431929e 编写于 作者: M Megvii Engine Team

feat(functional): let advance indexing support empty tensor and add more tests

GitOrigin-RevId: 49e1492934813caf4e491a901610b95439bac236
上级 703b783c
...@@ -176,6 +176,8 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): ...@@ -176,6 +176,8 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
def is_bool_list(x): def is_bool_list(x):
if not isinstance(x, list): if not isinstance(x, list):
return False return False
if len(x) == 0:
return False
for i in x: for i in x:
if not isinstance(i, bool): if not isinstance(i, bool):
return False return False
...@@ -246,17 +248,6 @@ def getitem(tensor, index): ...@@ -246,17 +248,6 @@ def getitem(tensor, index):
if len(try_result) == 2: if len(try_result) == 2:
return try_result[0] return try_result[0]
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
for v in tensors:
if v.shape is None:
break
if isinstance(v.shape, v.__class__):
break
if len(v.shape) > 0 and v.shape[0] == 0:
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
tensor
)
return empty_tensor
if use_subtensor: if use_subtensor:
op = builtin.Subtensor(items=items) op = builtin.Subtensor(items=items)
else: else:
......
...@@ -610,6 +610,25 @@ def test_subtensor_on_empty_tensor(symbolic): ...@@ -610,6 +610,25 @@ def test_subtensor_on_empty_tensor(symbolic):
run_test(lambda x: x[100:200, 300:400, 500:600]) run_test(lambda x: x[100:200, 300:400, 500:600])
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_indexingMultiAxisVec_on_empty_tensor(symbolic):
np_x = np.array([], dtype=np.float32).reshape(10, 10, 0)
mge_x = megengine.tensor(np_x)
def run_test(fn):
out_ref = fn(np_x)
if symbolic is not None:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(mge_x)
np.testing.assert_equal(out.numpy(), out_ref)
run_test(lambda x: x[[1, 2, 3]])
run_test(lambda x: x[[1, 2, 3], [4, 5, 6]])
run_test(lambda x: x[[]])
run_test(lambda x: x[[], [], []])
@pytest.mark.parametrize("symbolic", [True, False, None]) @pytest.mark.parametrize("symbolic", [True, False, None])
def test_setsubtensor_on_empty_tensor(symbolic): def test_setsubtensor_on_empty_tensor(symbolic):
def run_test(inp_shp, fn): def run_test(inp_shp, fn):
...@@ -655,3 +674,39 @@ def test_setsubtensor_on_empty_tensor(symbolic): ...@@ -655,3 +674,39 @@ def test_setsubtensor_on_empty_tensor(symbolic):
run_test((10, 10, 10), test4) run_test((10, 10, 10), test4)
run_test((10, 10, 10), test5) run_test((10, 10, 10), test5)
run_test((10, 10, 10), test6) run_test((10, 10, 10), test6)
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic):
def run_test(inp_shp, fn):
np_x = np.random.randn(*inp_shp).astype(np.float32)
mge_x = megengine.tensor(np_x)
out_ref = fn(np_x)
if symbolic is not None:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(mge_x)
np.testing.assert_equal(out.numpy(), out_ref)
def test1(x):
x[[1, 2, 3]] = x[[1, 2, 3]]
return x
def test2(x):
x[[1, 2, 3], [1, 2, 3]] = x[[1, 2, 3], [1, 2, 3]]
return x
def test3(x):
x[[]] = x[[]]
return x
def test4(x):
x[[], [], []] = x[[], [], []]
return x
run_test((10, 10, 0), test1)
run_test((10, 10, 0), test2)
run_test((10, 10, 0), test3)
run_test((10, 10, 0), test4)
run_test((10, 10, 10), test3)
run_test((10, 10, 10), test4)
...@@ -860,8 +860,8 @@ def test_condtake(): ...@@ -860,8 +860,8 @@ def test_condtake():
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
# @pytest.mark.parametrize("is_symbolic", [None, False, True]) @pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_condtake(is_symbolic=None): def test_condtake(is_symbolic):
shapes = [ shapes = [
(3, 3, 3), (3, 3, 3),
(0,), (0,),
......
...@@ -292,8 +292,6 @@ cg::OperatorNodeBase::NodeProp* ...@@ -292,8 +292,6 @@ cg::OperatorNodeBase::NodeProp*
IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const { IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const {
auto prop = Super::do_make_node_prop(); auto prop = Super::do_make_node_prop();
using DT = NodeProp::DepType; using DT = NodeProp::DepType;
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY); prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY);
for (auto i: m_input2idxonly_axis_indexer) { for (auto i: m_input2idxonly_axis_indexer) {
if (i) { if (i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册